码迷,mamicode.com
首页 > 编程语言 > 详细

梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

时间:2017-11-06 11:05:44      阅读:714      评论:0      收藏:0      [点我收藏+]

标签:training   exe   lines   number   orm   深度   数据预处理   for   lua   

梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

http://blog.csdn.net/liulingyuan6/article/details/53426350

 

梯度迭代树

算法简介:

        梯度提升树是一种决策树的集成算法。它通过反复迭代训练决策树来最小化损失函数。决策树类似,梯度提升树具有可处理类别特征、易扩展到多分类问题、不需特征缩放等性质。Spark.ml通过使用现有decision tree工具来实现。

       梯度提升树依次迭代训练一系列的决策树。在一次迭代中,算法使用现有的集成来对每个训练实例的类别进行预测,然后将预测结果与真实的标签值进行比较。通过重新标记,来赋予预测结果不好的实例更高的权重。所以,在下次迭代中,决策树会对先前的错误进行修正。

       对实例标签进行重新标记的机制由损失函数来指定。每次迭代过程中,梯度迭代树在训练数据上进一步减少损失函数的值。spark.ml为分类问题提供一种损失函数(Log Loss),为回归问题提供两种损失函数(平方误差与绝对误差)。

       Spark.ml支持二分类以及回归的随机森林算法,适用于连续特征以及类别特征。

*注意梯度提升树目前不支持多分类问题。

参数:

checkpointInterval:

类型:整数型。

含义:设置检查点间隔(>=1),或不设置检查点(-1)。

featuresCol:

类型:字符串型。

含义:特征列名。

impurity:

类型:字符串型。

含义:计算信息增益的准则(不区分大小写)。

labelCol:

类型:字符串型。

含义:标签列名。

lossType:

类型:字符串型。

含义:损失函数类型。

maxBins:

类型:整数型。

含义:连续特征离散化的最大数量,以及选择每个节点分裂特征的方式。

maxDepth:

类型:整数型。

含义:树的最大深度(>=0)。

maxIter:

类型:整数型。

含义:迭代次数(>=0)。

minInfoGain:

类型:双精度型。

含义:分裂节点时所需最小信息增益。

minInstancesPerNode:

类型:整数型。

含义:分裂后自节点最少包含的实例数量。

predictionCol:

类型:字符串型。

含义:预测结果列名。

rawPredictionCol:

类型:字符串型。

含义:原始预测。

seed:

类型:长整型。

含义:随机种子。

subsamplingRate:

类型:双精度型。

含义:学习一棵决策树使用的训练数据比例,范围[0,1]。

stepSize:

类型:双精度型。

含义:每次迭代优化步长。

示例:

       下面的例子导入LibSVM格式数据,并将之划分为训练数据和测试数据。使用第一部分数据进行训练,剩下数据来测试。训练之前我们使用了两种数据预处理方法来对特征进行转换,并且添加了元数据到DataFrame。

Scala:

 

[plain] view plain copy
 
  1. import org.apache.spark.ml.Pipeline  
  2. import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}  
  3. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator  
  4. import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}  
  5.   
  6. // Load and parse the data file, converting it to a DataFrame.  
  7. val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")  
  8.   
  9. // Index labels, adding metadata to the label column.  
  10. // Fit on whole dataset to include all labels in index.  
  11. val labelIndexer = new StringIndexer()  
  12.   .setInputCol("label")  
  13.   .setOutputCol("indexedLabel")  
  14.   .fit(data)  
  15. // Automatically identify categorical features, and index them.  
  16. // Set maxCategories so features with > 4 distinct values are treated as continuous.  
  17. val featureIndexer = new VectorIndexer()  
  18.   .setInputCol("features")  
  19.   .setOutputCol("indexedFeatures")  
  20.   .setMaxCategories(4)  
  21.   .fit(data)  
  22.   
  23. // Split the data into training and test sets (30% held out for testing).  
  24. val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))  
  25.   
  26. // Train a GBT model.  
  27. val gbt = new GBTClassifier()  
  28.   .setLabelCol("indexedLabel")  
  29.   .setFeaturesCol("indexedFeatures")  
  30.   .setMaxIter(10)  
  31.   
  32. // Convert indexed labels back to original labels.  
  33. val labelConverter = new IndexToString()  
  34.   .setInputCol("prediction")  
  35.   .setOutputCol("predictedLabel")  
  36.   .setLabels(labelIndexer.labels)  
  37.   
  38. // Chain indexers and GBT in a Pipeline.  
  39. val pipeline = new Pipeline()  
  40.   .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))  
  41.   
  42. // Train model. This also runs the indexers.  
  43. val model = pipeline.fit(trainingData)  
  44.   
  45. // Make predictions.  
  46. val predictions = model.transform(testData)  
  47.   
  48. // Select example rows to display.  
  49. predictions.select("predictedLabel", "label", "features").show(5)  
  50.   
  51. // Select (prediction, true label) and compute test error.  
  52. val evaluator = new MulticlassClassificationEvaluator()  
  53.   .setLabelCol("indexedLabel")  
  54.   .setPredictionCol("prediction")  
  55.   .setMetricName("accuracy")  
  56. val accuracy = evaluator.evaluate(predictions)  
  57. println("Test Error = " + (1.0 - accuracy))  
  58.   
  59. val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]  
  60. println("Learned classification GBT model:\n" + gbtModel.toDebugString)  


Java:

 

 

[java] view plain copy
 
  1. import org.apache.spark.ml.Pipeline;  
  2. import org.apache.spark.ml.PipelineModel;  
  3. import org.apache.spark.ml.PipelineStage;  
  4. import org.apache.spark.ml.classification.GBTClassificationModel;  
  5. import org.apache.spark.ml.classification.GBTClassifier;  
  6. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;  
  7. import org.apache.spark.ml.feature.*;  
  8. import org.apache.spark.sql.Dataset;  
  9. import org.apache.spark.sql.Row;  
  10. import org.apache.spark.sql.SparkSession;  
  11.   
  12. // Load and parse the data file, converting it to a DataFrame.  
  13. Dataset<Row> data = spark  
  14.   .read()  
  15.   .format("libsvm")  
  16.   .load("data/mllib/sample_libsvm_data.txt");  
  17.   
  18. // Index labels, adding metadata to the label column.  
  19. // Fit on whole dataset to include all labels in index.  
  20. StringIndexerModel labelIndexer = new StringIndexer()  
  21.   .setInputCol("label")  
  22.   .setOutputCol("indexedLabel")  
  23.   .fit(data);  
  24. // Automatically identify categorical features, and index them.  
  25. // Set maxCategories so features with > 4 distinct values are treated as continuous.  
  26. VectorIndexerModel featureIndexer = new VectorIndexer()  
  27.   .setInputCol("features")  
  28.   .setOutputCol("indexedFeatures")  
  29.   .setMaxCategories(4)  
  30.   .fit(data);  
  31.   
  32. // Split the data into training and test sets (30% held out for testing)  
  33. Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});  
  34. Dataset<Row> trainingData = splits[0];  
  35. Dataset<Row> testData = splits[1];  
  36.   
  37. // Train a GBT model.  
  38. GBTClassifier gbt = new GBTClassifier()  
  39.   .setLabelCol("indexedLabel")  
  40.   .setFeaturesCol("indexedFeatures")  
  41.   .setMaxIter(10);  
  42.   
  43. // Convert indexed labels back to original labels.  
  44. IndexToString labelConverter = new IndexToString()  
  45.   .setInputCol("prediction")  
  46.   .setOutputCol("predictedLabel")  
  47.   .setLabels(labelIndexer.labels());  
  48.   
  49. // Chain indexers and GBT in a Pipeline.  
  50. Pipeline pipeline = new Pipeline()  
  51.   .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});  
  52.   
  53. // Train model. This also runs the indexers.  
  54. PipelineModel model = pipeline.fit(trainingData);  
  55.   
  56. // Make predictions.  
  57. Dataset<Row> predictions = model.transform(testData);  
  58.   
  59. // Select example rows to display.  
  60. predictions.select("predictedLabel", "label", "features").show(5);  
  61.   
  62. // Select (prediction, true label) and compute test error.  
  63. MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()  
  64.   .setLabelCol("indexedLabel")  
  65.   .setPredictionCol("prediction")  
  66.   .setMetricName("accuracy");  
  67. double accuracy = evaluator.evaluate(predictions);  
  68. System.out.println("Test Error = " + (1.0 - accuracy));  
  69.   
  70. GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]);  
  71. System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());  


Python:

 

 

[python] view plain copy
 
    1. from pyspark.ml import Pipeline  
    2. from pyspark.ml.classification import GBTClassifier  
    3. from pyspark.ml.feature import StringIndexer, VectorIndexer  
    4. from pyspark.ml.evaluation import MulticlassClassificationEvaluator  
    5.   
    6. # Load and parse the data file, converting it to a DataFrame.  
    7. data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")  
    8.   
    9. # Index labels, adding metadata to the label column.  
    10. # Fit on whole dataset to include all labels in index.  
    11. labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)  
    12. # Automatically identify categorical features, and index them.  
    13. # Set maxCategories so features with > 4 distinct values are treated as continuous.  
    14. featureIndexer =\  
    15.     VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)  
    16.   
    17. # Split the data into training and test sets (30% held out for testing)  
    18. (trainingData, testData) = data.randomSplit([0.7, 0.3])  
    19.   
    20. # Train a GBT model.  
    21. gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)  
    22.   
    23. # Chain indexers and GBT in a Pipeline  
    24. pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])  
    25.   
    26. # Train model.  This also runs the indexers.  
    27. model = pipeline.fit(trainingData)  
    28.   
    29. # Make predictions.  
    30. predictions = model.transform(testData)  
    31.   
    32. # Select example rows to display.  
    33. predictions.select("prediction", "indexedLabel", "features").show(5)  
    34.   
    35. # Select (prediction, true label) and compute test error  
    36. evaluator = MulticlassClassificationEvaluator(  
    37.     labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")  
    38. accuracy = evaluator.evaluate(predictions)  
    39. print("Test Error = %g" % (1.0 - accuracy))  
    40.   
    41. gbtModel = model.stages[2]  
    42. print(gbtModel)  # summary only  

梯度迭代树(GBDT)算法原理及Spark MLlib调用实例(Scala/Java/python)

标签:training   exe   lines   number   orm   深度   数据预处理   for   lua   

原文地址:http://www.cnblogs.com/zhangbojiangfeng/p/7791762.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!