Keywords: Spark, MLlib, Gradient-Boosted Trees, ad click prediction
This post uses the Gradient-Boosted Trees algorithm from Spark MLlib to predict whether a user will click on an ad.
The training and test data come from the sample data of the Kaggle Avazu CTR competition, available at: https://www.kaggle.com/c/avazu-ctr-prediction/data
The data format is as follows:
It contains 24 fields:
- 1-id: ad identifier
- 2-click: 0/1 for non-click/click
- 3-hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
- 4-C1 — anonymized categorical variable
- 5-banner_pos
- 6-site_id
- 7-site_domain
- 8-site_category
- 9-app_id
- 10-app_domain
- 11-app_category
- 12-device_id
- 13-device_ip
- 14-device_model
- 15-device_type
- 16-device_conn_type
- 17-24 — C14-C21 — anonymized categorical variables
In the code below, columns 6 through 14 (site_id through device_model) are treated as categorical features, and columns 16 through 23 as numeric features.
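As a minimal standalone illustration of that column split (the sample line follows the Avazu field order described above; the object name and sample values are hypothetical, not part of the original program):

```scala
object ColumnSplitDemo {
  def main(args: Array[String]): Unit = {
    // A sample record in Avazu format: id, click, hour, C1, banner_pos,
    // then nine categorical hash columns, then the numeric columns.
    val line = "1000009418151094273,0,14102100,1005,0," +
      "1fbe01fe,f3845767,28905ebd,ecad2386,7801e8d9,07d7df22,a99f214a,ddd2926e,44956a24," +
      "1,2,15706,320,50,1722,0,35,-1,79"
    val tokens = line.split(",", -1)
    val catkey = tokens(0) + "::" + tokens(1)                  // id::click
    val catfeatures = tokens.slice(5, 14)                      // columns 6-14: categorical
    val numericalfeatures = tokens.slice(15, tokens.size - 1)  // columns 16-23: numeric
    println(catkey)
    println(catfeatures.length)
    println(numericalfeatures.mkString(","))
  }
}
```

Note that `slice(5, 14)` and `slice(15, tokens.size - 1)` skip column 15 (device_type) and the last column (C21); the program below inherits this behavior.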
The Spark code is as follows:
package com.lxw1234.test

import scala.collection.mutable.ListBuffer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.SparkConf

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

/**
 * By: lxw
 * http://lxw1234.com
 */
object CtrPredict {

  // input:  (1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9)
  // output: ((0,1fbe01fe), (1,f3845767), (2,28905ebd), (3,ecad2386), (4,7801e8d9))
  def parseCatFeatures(catfeatures: Array[String]): List[(Int, String)] = {
    val catfeatureList = new ListBuffer[(Int, String)]()
    for (i <- 0 until catfeatures.length) {
      catfeatureList += i -> catfeatures(i)
    }
    catfeatureList.toList
  }

  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("yarn-client")
    val sc = new SparkContext(conf)

    val ctrRDD = sc.textFile("/tmp/lxw1234/sample.txt", 10)
    println("Total records : " + ctrRDD.count)

    // Use 80% of the data set for training and 20% for testing
    val train_test_rdd = ctrRDD.randomSplit(Array(0.8, 0.2), seed = 37L)
    val train_raw_rdd = train_test_rdd(0)
    val test_raw_rdd = train_test_rdd(1)

    println("Train records : " + train_raw_rdd.count)
    println("Test records : " + test_raw_rdd.count)

    // Cache the training and test sets
    train_raw_rdd.cache()
    test_raw_rdd.cache()

    val train_rdd = train_raw_rdd.map { line =>
      val tokens = line.split(",", -1)
      // The key combines the id with the click label
      val catkey = tokens(0) + "::" + tokens(1)
      // Columns 6-14 are categorical features and need one-hot encoding
      val catfeatures = tokens.slice(5, 14)
      // Columns 16-23 are numeric features, used as-is
      val numericalfeatures = tokens.slice(15, tokens.size - 1)
      (catkey, catfeatures, numericalfeatures)
    }

    // Inspect one record
    train_rdd.take(1)
    // scala> train_rdd.take(1)
    // res6: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,
    //   Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
    //   Array(2, 15706, 320, 50, 1722, 0, 35, -1)))

    // Pair each categorical feature with its column index
    val train_cat_rdd = train_rdd.map { x => parseCatFeatures(x._2) }

    train_cat_rdd.take(1)
    // scala> train_cat_rdd.take(1)
    // res12: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd),
    //   (3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))

    // Deduplicate the (column index, value) pairs and assign each a feature ID
    val oheMap = train_cat_rdd.flatMap(x => x).distinct().zipWithIndex().collectAsMap()
    // oheMap: scala.collection.Map[(Int, String),Long] = Map((7,608511e9) -> 31527,
    //   (7,b2d8fbed) -> 42207, (7,1d3e2fdb) -> 52791, ...
    println("Number of features")
    println(oheMap.size)

    // Encode the training data
    val ohe_train_rdd = train_rdd.map { case (key, categorical_features, numerical_features) =>
      val cat_features_indexed = parseCatFeatures(categorical_features)
      val cat_feature_ohe = new ArrayBuffer[Double]
      for (k <- cat_features_indexed) {
        if (oheMap contains k) {
          cat_feature_ohe += oheMap(k).toDouble
        } else {
          cat_feature_ohe += 0.0
        }
      }
      // Clamp negative numeric values (e.g. -1) to 0
      val numerical_features_dbl = numerical_features.map { x =>
        val x1 = if (x.toInt < 0) "0" else x
        x1.toDouble
      }
      val features = cat_feature_ohe.toArray ++ numerical_features_dbl
      LabeledPoint(key.split("::")(1).toInt, Vectors.dense(features))
    }

    ohe_train_rdd.take(1)
    // res15: Array[org.apache.spark.mllib.regression.LabeledPoint] =
    //   Array((0.0,[43127.0,50023.0,57445.0,13542.0,31092.0,14800.0,23414.0,54121.0,
    //     17554.0,2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))

    // Train the model
    // val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 100
    boostingStrategy.treeStrategy.numClasses = 2
    boostingStrategy.treeStrategy.maxDepth = 10
    boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

    val model = GradientBoostedTrees.train(ohe_train_rdd, boostingStrategy)
    // Save the model
    model.save(sc, "/tmp/myGradientBoostingClassificationModel")
    // Load the model back
    val sameModel = GradientBoostedTreesModel.load(sc, "/tmp/myGradientBoostingClassificationModel")

    // Apply the same one-hot encoding to the test set
    val test_rdd = test_raw_rdd.map { line =>
      val tokens = line.split(",", -1)
      val catkey = tokens(0) + "::" + tokens(1)
      val catfeatures = tokens.slice(5, 14)
      val numericalfeatures = tokens.slice(15, tokens.size - 1)
      (catkey, catfeatures, numericalfeatures)
    }

    val ohe_test_rdd = test_rdd.map { case (key, categorical_features, numerical_features) =>
      val cat_features_indexed = parseCatFeatures(categorical_features)
      val cat_feature_ohe = new ArrayBuffer[Double]
      for (k <- cat_features_indexed) {
        if (oheMap contains k) {
          cat_feature_ohe += oheMap(k).toDouble
        } else {
          // Feature values never seen in training fall back to 0.0
          cat_feature_ohe += 0.0
        }
      }
      val numerical_features_dbl = numerical_features.map { x =>
        val x1 = if (x.toInt < 0) "0" else x
        x1.toDouble
      }
      val features = cat_feature_ohe.toArray ++ numerical_features_dbl
      LabeledPoint(key.split("::")(1).toInt, Vectors.dense(features))
    }

    // Score the test set
    val b = ohe_test_rdd.map { y =>
      val s = model.predict(y.features)
      (s, y.label, y.features)
    }
    b.take(10).foreach(println)

    // Prediction accuracy
    val predictions = ohe_test_rdd.map(lp => sameModel.predict(lp.features))
    predictions.take(10).foreach(println)
    val predictionAndLabel = predictions.zip(ohe_test_rdd.map(_.label))
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count / ohe_test_rdd.count
    println("GBTR accuracy " + accuracy)
    // GBTR accuracy 0.8227084119200302
  }
}
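The core of the encoding step — numbering each distinct (column index, value) pair — can be sketched without Spark using plain Scala collections (a simplified, hypothetical illustration of what `distinct().zipWithIndex().collectAsMap()` produces, not the program's actual distributed oheMap):

```scala
object OheMapDemo {
  // Same role as parseCatFeatures in the Spark job
  def parseCatFeatures(catfeatures: Array[String]): List[(Int, String)] =
    catfeatures.zipWithIndex.map { case (v, i) => (i, v) }.toList

  def main(args: Array[String]): Unit = {
    val rows = List(
      Array("1fbe01fe", "f3845767"),
      Array("d6137915", "f3845767")
    )
    // Deduplicate (column, value) pairs and assign each a feature ID,
    // mirroring distinct().zipWithIndex().collectAsMap()
    val oheMap = rows.flatMap(parseCatFeatures).distinct.zipWithIndex.toMap
    println(oheMap.size)  // 3 distinct pairs across the two rows
    // Encode a row: look up each pair's ID, defaulting to 0.0 for unseen values
    val encoded = parseCatFeatures(Array("1fbe01fe", "unknown"))
      .map(k => oheMap.get(k).map(_.toDouble).getOrElse(0.0))
    println(encoded.mkString(","))
  }
}
```

One caveat of this scheme: an unseen value is encoded as 0.0, which collides with the pair that happened to receive feature ID 0, and the resulting vector is an index encoding rather than a true one-hot 0/1 expansion — the trees treat these IDs as ordered numeric values.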
The split yields a training set of 104558 records (Train records : 104558) and a test set of 26510 records (Test records : 26510).
Main program output:
scala> train_rdd.take(1)
res23: Array[(String, Array[String], Array[String])] = Array((1000009418151094273::0,
  Array(1fbe01fe, f3845767, 28905ebd, ecad2386, 7801e8d9, 07d7df22, a99f214a, ddd2926e, 44956a24),
  Array(2, 15706, 320, 50, 1722, 0, 35, -1)))

scala> train_cat_rdd.take(1)
res24: Array[List[(Int, String)]] = Array(List((0,1fbe01fe), (1,f3845767), (2,28905ebd),
  (3,ecad2386), (4,7801e8d9), (5,07d7df22), (6,a99f214a), (7,ddd2926e), (8,44956a24)))

scala> println("Number of features")
Number of features
scala> println(oheMap.size)
57606

scala> ohe_train_rdd.take(1)
res27: Array[org.apache.spark.mllib.regression.LabeledPoint] =
  Array((0.0,[11602.0,22813.0,11497.0,16828.0,30657.0,23893.0,13182.0,31723.0,39722.0,
    2.0,15706.0,320.0,50.0,1722.0,0.0,35.0,0.0]))

scala> println("GBTR accuracy " + accuracy)
GBTR accuracy 0.8227084119200302
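The accuracy figure above is simply the fraction of test records whose predicted label matches the true label. In plain Scala, over a small hypothetical set of (prediction, label) pairs:

```scala
object AccuracyDemo {
  def main(args: Array[String]): Unit = {
    // Four (prediction, label) pairs; three agree, one does not
    val predictionAndLabel = Seq((1.0, 1.0), (0.0, 0.0), (1.0, 0.0), (0.0, 0.0))
    // Matching pairs divided by total, as in the Spark job's filter/count
    val accuracy = 1.0 * predictionAndLabel.count { case (p, l) => p == l } /
      predictionAndLabel.size
    println(accuracy)  // 0.75
  }
}
```

Note that with roughly 83% of Avazu records being non-clicks, a model that always predicts "no click" would score similarly on raw accuracy, so metrics such as AUC or log loss are usually more informative for CTR data.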