标签:mlu 使用 eee learn open 参数 如何 gif dict
转载请标明出处http://www.cnblogs.com/haozhengfei/p/8b9cb1875288d9f6cfc2f5a9b2f10eac.html
决策树 | 分类决策树 | 用于预测离散的类别标签,如晴天/阴天/雾/雨、用户性别、网页是否是垃圾页面。 |
决策树 | 回归决策树 | 用于预测实数值,如明天的温度、用户的年龄、网页的相关程度。 |
强调:回归决策树的结果(数值)做加减运算是有意义的,而分类决策树的结果做加减运算没有意义,因为它的结果是类别。 |
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.feature.{StandardScaler, StandardScalerModel}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{GradientBoostedTrees, DecisionTree}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.mllib.tree.impurity.Entropy
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

import scala.util.Try

/**
 * Trains a two-class gradient-boosted-trees classifier with Spark MLlib, evaluates it on a
 * held-out split of the same input, prints the accuracy and error, and saves the model.
 *
 * Command-line arguments, in order:
 *   - `inputPath`: location read with `MLUtils.loadLabeledPoints`
 *   - `modelPath`: location passed to `model.save`
 *   - `maxDepth`:  integer maximum depth of each tree
 *   - `master`:    Spark master URL
 *   - `AppName`:   optional application name (defaults to "DecisionTrees")
 *
 * Created by hzf
 */
object GBDT_new {
  // Example arguments:
  // E:\IDEA_Projects\mlib\data\GBDT\train E:\IDEA_Projects\mlib\data\GBDT\train\model 10 local

  /**
   * Prints `reason` (when non-empty) followed by the usage text to stderr, then exits with code 1.
   *
   * @param reason optional explanation of what was wrong with the arguments
   */
  private def exitWithUsage(reason: String = ""): Nothing = {
    if (reason.nonEmpty) System.err.println(reason)
    System.err.println("Usage: DecisionTrees <inputPath> <modelPath> <maxDepth> <master> [<AppName>]")
    // The example follows the same argument order as the usage line above.
    System.err.println("eg: hdfs://192.168.57.104:8020/user/000000_0 hdfs://192.168.57.104:8020/user/gbdt_model 10 spark://192.168.57.104:7077 DecisionTrees")
    sys.exit(1)
  }

  /**
   * Program entry point; see the object documentation for the expected arguments.
   *
   * @param args command-line arguments (at least four are required)
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    if (args.length < 4) exitWithUsage()

    // Validate before a SparkContext exists so that a bad value fails fast with a clear message.
    val maxDepth: Int = Try(args(2).toInt).getOrElse(
      exitWithUsage(s"<maxDepth> must be an integer, got: ${args(2)}"))

    val appName = if (args.length > 4) args(4) else "DecisionTrees"
    val conf = new SparkConf().setAppName(appName).setMaster(args(3))
    val sc = new SparkContext(conf)

    try {
      val allData: RDD[LabeledPoint] = MLUtils.loadLabeledPoints(sc, args(0))

      // NOTE: earlier revisions fitted a StandardScaler here but then split and trained on the
      // raw data, so the scaled RDD was never used and the fit only cost an extra pass over the
      // input. Tree splits are per-feature thresholds, which do not depend on feature scale, so
      // the dead step is dropped rather than wired in; the saved model keeps accepting raw
      // feature vectors.

      // Weights 1:9 => roughly 10% of the samples are held out for testing.
      val Array(testData, realTrainData) = allData.randomSplit(Array(1.0, 9.0))

      val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams("Classification")
      boostingStrategy.setNumIterations(3)
      boostingStrategy.treeStrategy.setNumClasses(2)
      boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
      boostingStrategy.setLearningRate(0.8)
      // boostingStrategy.treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())
      val model = GradientBoostedTrees.train(realTrainData, boostingStrategy)

      val labelAndPreds = testData.map(point => (point.label, model.predict(point.features)))
      val correct = labelAndPreds.filter { case (label, prediction) => label == prediction }.count()
      // Fraction of correctly classified test samples (NaN when the test split is empty).
      val accuracy = correct.toDouble / testData.count()

      // The ratio above is accuracy, not error; report both under their correct names.
      println(s"Test Accuracy = $accuracy")
      println(s"Test Error = ${1.0 - accuracy}")

      model.save(sc, args(1))
    } finally {
      // Release the Spark resources even when training or saving fails.
      sc.stop()
    }
  }
}
运行参数示例(依次为 inputPath modelPath maxDepth master):E:\IDEA_Projects\mlib\data\GBDT\train E:\IDEA_Projects\mlib\data\GBDT\train\model 10 local
标签:mlu 使用 eee learn open 参数 如何 gif dict
原文地址:http://www.cnblogs.com/haozhengfei/p/8b9cb1875288d9f6cfc2f5a9b2f10eac.html