使用OneHotEncoderEstimator操作Affairs
生活随笔收集整理的這篇文章主要介紹了「使用OneHotEncoderEstimator操作Affairs」,小編覺得挺不錯的,現在分享給大家,幫大家做個參考。
2019獨角獸企業重金招聘Python工程師標準>>>
// spark-shell script: binary classification on the Affairs data set.
//
// Steps, in order:
//   1. load the CSV and binarise the `affairs` column into a 0/1 label,
//   2. index and one-hot encode every string-typed column,
//   3. assemble all non-label columns into a single `features` vector,
//   4. tune an XGBoost classifier with 3-fold cross-validation,
//   5. score a held-out 20% split and show the encoded columns.
//
// `spark` is expected to be the SparkSession already in scope (as in spark-shell).
// Multi-line expressions are wrapped in parentheses so they can be pasted into
// the REPL without `:paste`.

import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoderEstimator, VectorAssembler}
import org.apache.spark.ml.Pipeline
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel

// ---------------------------------------------------------------------------
// 1. Load the data and build the label
// ---------------------------------------------------------------------------
val data = (spark.read
  .format("csv")
  .option("sep", ",")
  .option("inferSchema", "true")
  .option("header", "true")
  .load("/user/spark/H2O/Affairs.csv"))

data.createOrReplaceTempView("res1")

// SQL fragment that turns the raw `affairs` value into a binary label:
// any value > 0 becomes 1, everything else 0.
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"

val df = (spark.sql(
  "select " + affairs +
    "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
    " from res1 "))

// ---------------------------------------------------------------------------
// 2. Index + one-hot encode the string columns
// ---------------------------------------------------------------------------
// Names of all columns whose inferred type is string.
val categoricals = df.dtypes.collect { case (name, "StringType") => name }

// One StringIndexer per categorical column: <col> -> <col>_idx.
val indexers = categoricals.map(c =>
  new StringIndexer()
    .setInputCol(c)
    .setOutputCol(s"${c}_idx")
    .setHandleInvalid("keep")
)

// One OneHotEncoderEstimator per categorical column: <col>_idx -> <col>_enc.
val encoders = categoricals.map(c =>
  new OneHotEncoderEstimator()
    .setInputCols(Array(s"${c}_idx"))
    .setOutputCols(Array(s"${c}_enc"))
    .setDropLast(false)
    .setHandleInvalid("keep")
)

// ---------------------------------------------------------------------------
// 3. Assemble the feature vector
// ---------------------------------------------------------------------------
val colArray_enc = categoricals.map(x => x + "_enc")
val colArray_numeric = df.dtypes.collect { case (name, tpe) if tpe != "StringType" => name }

// Every numeric and encoded column except the label itself. The label is
// excluded by exact name so that an unrelated column whose name merely
// contains "affairs" would not be dropped by accident.
val final_colArray = (colArray_numeric ++ colArray_enc).filter(_ != "affairs")

val vectorAssembler = (new VectorAssembler()
  .setInputCols(final_colArray)
  .setOutputCol("features"))

/*
// Optional: run only the feature-engineering stages to inspect their output.
val test_pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler))
test_pipeline.fit(df).transform(df)
*/

// ---------------------------------------------------------------------------
// 4. XGBoost classifier + cross-validation
// ---------------------------------------------------------------------------
// Create an XGBoost Classifier.
// NOTE(review): the parameter map is passed through unchanged from the original
// script; confirm against the XGBoost4J-Spark version in use that "num_class"
// and "num_rounds" are the intended keys for a binary:logistic objective. The
// number of rounds used during cross-validation is set by `xgb.round` in the
// grid below.
val xgb = (new XGBoostEstimator(
  Map(
    "num_class" -> 2,
    "num_rounds" -> 5,
    "objective" -> "binary:logistic",
    "booster" -> "gbtree"))
  .setLabelCol("affairs")
  .setFeaturesCol("features"))

// XGBoost parameter grid. Only maxDepth has more than one candidate, so the
// grid contains two parameter combinations.
val xgbParamGrid = (new ParamGridBuilder()
  .addGrid(xgb.round, Array(10))
  .addGrid(xgb.maxDepth, Array(10, 20))
  .addGrid(xgb.minChildWeight, Array(0.1))
  .addGrid(xgb.gamma, Array(0.1))
  .addGrid(xgb.subSample, Array(0.8))
  .addGrid(xgb.colSampleByTree, Array(0.90))
  .addGrid(xgb.alpha, Array(0.0))
  .addGrid(xgb.lambda, Array(0.6))
  .addGrid(xgb.scalePosWeight, Array(0.1))
  .addGrid(xgb.eta, Array(0.4))
  .addGrid(xgb.boosterType, Array("gbtree"))
  .addGrid(xgb.objective, Array("binary:logistic"))
  .build())

// Create the XGBoost pipeline
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))

// Setup the binary classifier evaluator.
// areaUnderROC needs a continuous score to rank rows. The model writes its
// per-class probabilities to the `probabilities` vector column (see the sample
// output below), so that column is used as the raw prediction. Pointing the
// evaluator at the hard 0/1 `prediction` column instead would reduce the ROC
// curve to a single operating point and make model selection unreliable.
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("probabilities")
  .setMetricName("areaUnderROC"))

// Create the Cross Validation pipeline, using XGBoost as the estimator, the
// Binary Classification evaluator, and xgbParamGrid for hyperparameters
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(xgbParamGrid)
  .setNumFolds(3)
  .setSeed(0))

// ---------------------------------------------------------------------------
// 5. Train, score and inspect
// ---------------------------------------------------------------------------
val Array(trainingData, testData) = df.randomSplit(Array(0.8, 0.2), seed = 0)

// Create the model by fitting the training data
val xgbModel = cv.fit(trainingData)

// Test the data by scoring the model
val results = xgbModel.transform(testData)

/* Sample output recorded with the original post (values may differ between runs
   and library versions):

scala> results.select("affairs", "probabilities", "prediction").show(false)
+-------+-----------------------------------------+----------+
|affairs|probabilities                            |prediction|
+-------+-----------------------------------------+----------+
|0      |[0.9525144696235657,0.04748553782701492] |0.0       |
|0      |[0.9776982069015503,0.02230178564786911] |0.0       |
|0      |[0.968203604221344,0.031796377152204514] |0.0       |
|0      |[0.9699327945709229,0.03006718121469021] |0.0       |
|0      |[0.976881742477417,0.023118266835808754] |0.0       |
|0      |[0.9741477966308594,0.025852231308817863]|0.0       |
|0      |[0.9741477966308594,0.025852231308817863]|0.0       |
|0      |[0.9775936603546143,0.022406354546546936]|0.0       |
|0      |[0.9776982069015503,0.02230178564786911] |0.0       |
|0      |[0.9775936603546143,0.022406354546546936]|0.0       |
|0      |[0.9776982069015503,0.02230178564786911] |0.0       |
|0      |[0.9720195531845093,0.02798045612871647] |0.0       |
|0      |[0.9693607091903687,0.0306392814964056]  |0.0       |
|0      |[0.976881742477417,0.023118266835808754] |0.0       |
|0      |[0.9646676778793335,0.035332340747117996]|0.0       |
|0      |[0.9624955654144287,0.03750446066260338] |0.0       |
|0      |[0.966502845287323,0.03349713608622551]  |0.0       |
|0      |[0.9776982069015503,0.02230178564786911] |0.0       |
|0      |[0.9636635184288025,0.03633648902177811] |0.0       |
|0      |[0.9696801900863647,0.030319783836603165]|0.0       |
+-------+-----------------------------------------+----------+
only showing top 20 rows
*/

// Show the raw, indexed and one-hot encoded forms of the categorical columns.
results.select("gender", "children", "gender_idx", "children_idx", "gender_enc", "children_enc").show()

/* Sample output recorded with the original post:

scala> results.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+-------------+-------------+
|gender|children|gender_idx|children_idx|   gender_enc| children_enc|
+------+--------+----------+------------+-------------+-------------+
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|     yes|       0.0|         0.0|(2,[0],[1.0])|(2,[0],[1.0])|
|female|     yes|       0.0|         0.0|(2,[0],[1.0])|(2,[0],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|     yes|       0.0|         0.0|(2,[0],[1.0])|(2,[0],[1.0])|
|female|     yes|       0.0|         0.0|(2,[0],[1.0])|(2,[0],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|      no|       0.0|         1.0|(2,[0],[1.0])|(2,[1],[1.0])|
|female|     yes|       0.0|         0.0|(2,[0],[1.0])|(2,[0],[1.0])|
+------+--------+----------+------------+-------------+-------------+
only showing top 20 rows
*/

// Reposted from: https://my.oschina.net/kyo4321/blog/2994654
與50位技術專家面對面20年技術見證,附贈技術全景圖總結
以上是生活随笔為你收集整理的使用OneHotEncoderEstimator操作Affairs的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 2019 CES展即将开启 思岚科技将会
- 下一篇: bzoj4589: Hard Nim