码迷,mamicode.com
首页 > 其他好文 > 详细

Spark 决策树--回归模型

时间:2017-11-06 16:20:07      阅读:344      评论:0      收藏:0      [点我收藏+]

标签:rri   1.4   回归   build   oop   row   mat   highlight   line   

package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}

/**
 * One parsed input record: a dense feature vector and its raw string label.
 *
 * NOTE(review): "scheam" is a misspelling of "schema"; the name is kept unchanged
 * in case other code in the `Spark_MLlib` package refers to it.
 *
 * @param features the numeric feature columns of a record
 * @param label    the record's class label as read from the file
 */
case class data_scheam(
  features: Vector,
  label: String
)
/**
 * Decision-tree regression demo over a comma-separated text file.
 *
 * Each input line is expected to hold four numeric features followed by a string
 * label (e.g. `5.1,3.5,1.4,0.2,hadoop`). The string label is indexed to a Double so
 * it can serve as the regression target, a [[DecisionTreeRegressor]] is trained in a
 * [[Pipeline]], the RMSE on a held-out split is printed, and the learned tree is dumped.
 */
object 决策树__回归模型 {
  val spark = SparkSession.builder().master("local").getOrCreate()
  import spark.implicits._

  /** Input file used when no path is given on the command line. */
  private val DefaultInputPath = "file:///home/soyo/桌面/spark编程测试数据/soyo2.txt"

  /**
   * Trains and evaluates the regression pipeline, then stops the Spark session.
   *
   * @param args optional; `args(0)` overrides the input file path
   */
  def main(args: Array[String]): Unit = {
    try {
      val inputPath = args.headOption.getOrElse(DefaultInputPath)

      // Records are built with the case class declared in this file as `data_scheam`
      // (the previous reference to `data_schema` did not compile).
      val data = spark.sparkContext
        .textFile(inputPath)
        .filter(_.trim.nonEmpty) // skip blank lines (e.g. a trailing newline) that would fail on x(4)
        .map(_.split(","))
        .map(x => data_scheam(Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble), x(4)))
        .toDF()

      // The indexers are fitted on the full data set (not just the training split) so
      // every label/feature value that appears in the test split is already known.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data)
      // Features with at most 4 distinct values are treated as categorical.
      val featuresIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(4)
        .fit(data)
      // Maps the numeric prediction back to a label string.
      // NOTE(review): regression predictions can be non-integral (a leaf may average
      // several classes); confirm how IndexToString handles such values.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      // Unseeded split, so the train/test partition (and the RMSE) vary between runs.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3))

      // Decision-tree regressor configuration.
      val dtRegressor = new DecisionTreeRegressor()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")

      // ML workflow: label indexing -> feature indexing -> regressor -> label conversion.
      val pipelineRegressor = new Pipeline()
        .setStages(Array(labelIndexer, featuresIndexer, dtRegressor, labelConverter))

      // Train the model, then predict on the held-out split.
      val modelRegressor = pipelineRegressor.fit(trainData)
      val prediction = modelRegressor.transform(testData)
      prediction.show(150)

      // Evaluate the regression model. setMetricName selects the metric:
      // "rmse", "mse", "r2" or "mae".
      val evaluatorRegressor = new RegressionEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("rmse")
      val rootMeanSquaredError = evaluatorRegressor.evaluate(prediction)
      println("均方根误差为: " + rootMeanSquaredError)

      // Stage 2 of the fitted pipeline is the trained tree (index of dtRegressor above).
      val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
      // NOTE(review): the message says "分类" (classification) although this is a
      // regression model; it is left as-is to match the recorded output.
      println("决策树分类模型的结构为: " + treeModelRegressor.toDebugString)
    } finally {
      // Release the local Spark context even if the job fails.
      spark.stop()
    }
  }
}
Spark 源码(RegressionEvaluator.evaluate):setMetricName(...) 可选的取值 rmse、mse、r2、mae 及其对应的度量如下
/**
 * Quoted from Spark's RegressionEvaluator: scores the predictions in `dataset`
 * against its labels using the metric selected by `metricName`.
 *
 * @param dataset data containing the prediction column and the label column
 * @return the selected metric: RMSE, MSE, R^2 or MAE
 */
@Since("2.0.0")
  override def evaluate(dataset: Dataset[_]): Double = {
    val schema = dataset.schema
    // The prediction column must be Double or Float; the label column must be numeric.
    SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
    SchemaUtils.checkNumericType(schema, $(labelCol))

    // Cast both columns to Double and collect them as an RDD of (prediction, label) pairs.
    val predictionAndLabels = dataset
      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
      .rdd
      .map { case Row(prediction: Double, label: Double) => (prediction, label) }
    val metrics = new RegressionMetrics(predictionAndLabels)
    // Map the metric name to the corresponding RegressionMetrics value.
    // NOTE(review): there is no catch-all case; presumably metricName is validated
    // when it is set, so only these four values can reach this match — confirm in Spark.
    val metric = $(metricName) match {
      case "rmse" => metrics.rootMeanSquaredError
      case "mse" => metrics.meanSquaredError
      case "r2" => metrics.r2
      case "mae" => metrics.meanAbsoluteError
    }
    metric
}

结果:

+-----------------+------+------------+-----------------+----------+--------------+
|         features| label|indexedLabel|  indexedFeatures|prediction|predictedLabel|
+-----------------+------+------------+-----------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|hadoop|         1.0|[4.6,3.1,1.5,0.2]|       1.0|        hadoop|
|[4.6,3.4,1.4,0.3]|hadoop|         1.0|[4.6,3.4,1.4,0.3]|       1.0|        hadoop|
|[4.7,3.2,1.3,0.2]|hadoop|         1.0|[4.7,3.2,1.3,0.2]|       1.0|        hadoop|
|[4.8,3.0,1.4,0.1]|hadoop|         1.0|[4.8,3.0,1.4,0.1]|       1.0|        hadoop|
|[5.1,3.3,1.7,0.5]|hadoop|         1.0|[5.1,3.3,1.7,0.5]|       1.0|        hadoop|
|[5.1,3.7,1.5,0.4]|hadoop|         1.0|[5.1,3.7,1.5,0.4]|       1.0|        hadoop|
|[5.4,3.9,1.3,0.4]|hadoop|         1.0|[5.4,3.9,1.3,0.4]|       1.0|        hadoop|
|[5.5,2.3,4.0,1.3]| spark|         0.0|[5.5,2.3,4.0,1.3]|       0.0|         spark|
|[5.5,3.5,1.3,0.2]|hadoop|         1.0|[5.5,3.5,1.3,0.2]|       1.0|        hadoop|
|[5.6,2.7,4.2,1.3]| spark|         0.0|[5.6,2.7,4.2,1.3]|       0.0|         spark|
|[5.6,3.0,4.1,1.3]| spark|         0.0|[5.6,3.0,4.1,1.3]|       0.0|         spark|
|[5.6,3.0,4.5,1.5]| spark|         0.0|[5.6,3.0,4.5,1.5]|       0.0|         spark|
|[5.7,2.6,3.5,1.0]| spark|         0.0|[5.7,2.6,3.5,1.0]|       0.0|         spark|
|[5.7,4.4,1.5,0.4]|hadoop|         1.0|[5.7,4.4,1.5,0.4]|       1.0|        hadoop|
|[5.8,2.7,3.9,1.2]| spark|         0.0|[5.8,2.7,3.9,1.2]|       0.0|         spark|
|[5.8,2.7,4.1,1.0]| spark|         0.0|[5.8,2.7,4.1,1.0]|       0.0|         spark|
|[5.8,2.8,5.1,2.4]| Scala|         2.0|[5.8,2.8,5.1,2.4]|       2.0|         Scala|
|[5.8,4.0,1.2,0.2]|hadoop|         1.0|[5.8,4.0,1.2,0.2]|       1.0|        hadoop|
|[5.9,3.0,4.2,1.5]| spark|         0.0|[5.9,3.0,4.2,1.5]|       0.0|         spark|
|[5.9,3.0,5.1,1.8]| Scala|         2.0|[5.9,3.0,5.1,1.8]|       2.0|         Scala|
|[5.9,3.2,4.8,1.8]| spark|         0.0|[5.9,3.2,4.8,1.8]|       2.0|         Scala|
|[6.1,2.6,5.6,1.4]| Scala|         2.0|[6.1,2.6,5.6,1.4]|       2.0|         Scala|
|[6.1,2.8,4.0,1.3]| spark|         0.0|[6.1,2.8,4.0,1.3]|       0.0|         spark|
|[6.3,2.9,5.6,1.8]| Scala|         2.0|[6.3,2.9,5.6,1.8]|       2.0|         Scala|
|[6.3,3.4,5.6,2.4]| Scala|         2.0|[6.3,3.4,5.6,2.4]|       2.0|         Scala|
|[6.4,2.7,5.3,1.9]| Scala|         2.0|[6.4,2.7,5.3,1.9]|       2.0|         Scala|
|[6.4,3.1,5.5,1.8]| Scala|         2.0|[6.4,3.1,5.5,1.8]|       2.0|         Scala|
|[6.4,3.2,4.5,1.5]| spark|         0.0|[6.4,3.2,4.5,1.5]|       0.0|         spark|
|[6.5,2.8,4.6,1.5]| spark|         0.0|[6.5,2.8,4.6,1.5]|       0.0|         spark|
|[6.5,3.0,5.5,1.8]| Scala|         2.0|[6.5,3.0,5.5,1.8]|       2.0|         Scala|
|[6.7,3.0,5.2,2.3]| Scala|         2.0|[6.7,3.0,5.2,2.3]|       2.0|         Scala|
|[6.7,3.1,4.7,1.5]| spark|         0.0|[6.7,3.1,4.7,1.5]|       0.0|         spark|
|[6.8,3.0,5.5,2.1]| Scala|         2.0|[6.8,3.0,5.5,2.1]|       2.0|         Scala|
|[6.9,3.1,5.4,2.1]| Scala|         2.0|[6.9,3.1,5.4,2.1]|       2.0|         Scala|
|[7.0,3.2,4.7,1.4]| spark|         0.0|[7.0,3.2,4.7,1.4]|       0.0|         spark|
|[7.1,3.0,5.9,2.1]| Scala|         2.0|[7.1,3.0,5.9,2.1]|       2.0|         Scala|
|[7.2,3.0,5.8,1.6]| Scala|         2.0|[7.2,3.0,5.8,1.6]|       0.0|         spark|
|[7.2,3.2,6.0,1.8]| Scala|         2.0|[7.2,3.2,6.0,1.8]|       2.0|         Scala|
|[7.2,3.6,6.1,2.5]| Scala|         2.0|[7.2,3.6,6.1,2.5]|       2.0|         Scala|
|[7.4,2.8,6.1,1.9]| Scala|         2.0|[7.4,2.8,6.1,1.9]|       2.0|         Scala|
|[7.7,2.6,6.9,2.3]| Scala|         2.0|[7.7,2.6,6.9,2.3]|       2.0|         Scala|
|[7.7,2.8,6.7,2.0]| Scala|         2.0|[7.7,2.8,6.7,2.0]|       2.0|         Scala|
+-----------------+------+------------+-----------------+----------+--------------+

均方根误差为: 0.43643578047198484
决策树分类模型的结构为: DecisionTreeRegressionModel (uid=dtr_6015411b1a3d) of depth 4 with 11 nodes
  If (feature 3 <= 1.7)
   If (feature 2 <= 1.9)
    Predict: 1.0
   Else (feature 2 > 1.9)
    If (feature 2 <= 4.9)
     If (feature 3 <= 1.6)
      Predict: 0.0
     Else (feature 3 > 1.6)
      Predict: 2.0
    Else (feature 2 > 4.9)
     If (feature 3 <= 1.5)
      Predict: 2.0
     Else (feature 3 > 1.5)
      Predict: 0.0
  Else (feature 3 > 1.7)
   Predict: 2.0

Spark 决策树--回归模型

标签:rri   1.4   回归   build   oop   row   mat   highlight   line   

原文地址:http://www.cnblogs.com/soyo/p/7793664.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!