码迷,mamicode.com
首页 > 其他好文 > 详细

Spark学习笔记——手写数字识别

时间:2017-05-26 00:38:43      阅读:518      评论:0      收藏:0      [点我收藏+]

标签:rate   svm   create   line   str   for   bsp   cal   ice   

import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, SVMWithSGD}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Entropy

/**
  * Created by common on 17-5-17.
  */

/**
  * Pairs an integer label with a list of pixel values.
  *
  * NOTE(review): not referenced by `DigitRecognizer.main` in this file; presumably
  * the label is the digit 0-9 the picture shows — confirm before relying on it.
  *
  * @param label integer label of the picture
  * @param pic   pixel values; empty when not supplied
  */
case class LabeledPic(
  label: Int,
  pic: List[Double] = Nil
)

object DigitRecognizer {

  /**
    * Loads the Kaggle "Digit Recognizer" training CSV as an RDD of [[LabeledPoint]]s.
    *
    * Each training record is `label,pixel0,pixel1,...`: column 0 is parsed as the
    * integer label and the remaining columns become a dense feature vector.
    * The classifier experiments below are kept commented out; uncomment one at a
    * time to train it, print its accuracy and write test-set predictions.
    *
    * @param args optional; `args(0)` overrides the training-file URI, otherwise the
    *             original hard-coded local path is used
    */
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("DigitRecgonizer").setMaster("local")
    val sc = new SparkContext(conf)

    // Always release the SparkContext, even when loading or training throws.
    try {
      // The header can be stripped beforehand with `sed 1d train.csv > train_noheader.csv`.
      // Lines that do not start with a digit (a leftover "label,pixel0,..." header or a
      // blank line) are skipped here so they cannot break the toInt parse below.
      val trainFile = args.headOption.getOrElse(
        "file:///media/common/工作/kaggle/DigitRecognizer/train_noheader.csv")
      val trainRawData = sc.textFile(trainFile)
        .filter(line => line.nonEmpty && line.head.isDigit)
      // Split each line on commas, producing an RDD of String arrays.
      val trainRecords = trainRawData.map(line => line.split(","))

      val trainData = trainRecords.map { r =>
        val label = r(0).toInt
        val features = r.tail.map(_.toDouble)
        LabeledPoint(label, Vectors.dense(features))
      }

      // NOTE: every accuracy computed below is measured on the training data itself, so
      // it is an optimistic estimate; hold out a split (e.g. trainData.randomSplit) for a
      // fair comparison between models.

      // --- Naive Bayes (MLlib's multinomial NB requires non-negative feature values) ---
      // val nbModel = NaiveBayes.train(trainData)
      //
      // val nbTotalCorrect = trainData.map { point =>
      //   if (nbModel.predict(point.features) == point.label) 1 else 0
      // }.sum
      // val nbAccuracy = nbTotalCorrect / trainData.count
      //
      // println("贝叶斯模型正确率:" + nbAccuracy)
      //
      // // Predict on the test set (every column is a feature; there is no label column).
      // val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
      // // Split each line on commas, producing an RDD of String arrays.
      // val testRecords = testRawData.map(line => line.split(","))
      //
      // val testData = testRecords.map { r =>
      //   val features = r.map(d => d.toDouble)
      //   Vectors.dense(features)
      // }
      // val predictions = nbModel.predict(testData).map(p => p.toInt)
      // // Save the predictions as a single part file.
      // predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")


      // --- Multinomial logistic regression (L-BFGS, 10 classes) ---
      // val lrModel = new LogisticRegressionWithLBFGS()
      //   .setNumClasses(10)
      //   .run(trainData)
      //
      // val lrTotalCorrect = trainData.map { point =>
      //   if (lrModel.predict(point.features) == point.label) 1 else 0
      // }.sum
      // val lrAccuracy = lrTotalCorrect / trainData.count
      //
      // println("线性回归模型正确率:" + lrAccuracy)
      //
      // // Predict on the test set (every column is a feature; there is no label column).
      // val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
      // // Split each line on commas, producing an RDD of String arrays.
      // val testRecords = testRawData.map(line => line.split(","))
      //
      // val testData = testRecords.map { r =>
      //   val features = r.map(d => d.toDouble)
      //   Vectors.dense(features)
      // }
      // val predictions = lrModel.predict(testData).map(p => p.toInt)
      // // Save the predictions as a single part file.
      // predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict1")


      // --- Decision tree (entropy impurity, max depth 10, 10 classes) ---
      // val maxTreeDepth = 10
      // val numClass = 10
      // val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
      //
      // val dtTotalCorrect = trainData.map { point =>
      //   if (dtModel.predict(point.features) == point.label) 1 else 0
      // }.sum
      // val dtAccuracy = dtTotalCorrect / trainData.count
      //
      // println("决策树模型正确率:" + dtAccuracy)
      //
      // // Predict on the test set (every column is a feature; there is no label column).
      // val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
      // // Split each line on commas, producing an RDD of String arrays.
      // val testRecords = testRawData.map(line => line.split(","))
      //
      // val testData = testRecords.map { r =>
      //   val features = r.map(d => d.toDouble)
      //   Vectors.dense(features)
      // }
      // val predictions = dtModel.predict(testData).map(p => p.toInt)
      // // Save the predictions as a single part file.
      // predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict2")


      // --- Random forest ---
      // val numClasses = 10 // one class per digit label; must match the label range
      // val categoricalFeaturesInfo = Map[Int, Int]()
      // val numTrees = 50
      // val featureSubsetStrategy = "auto"
      // val impurity = "gini"
      // val maxDepth = 10
      // val maxBins = 32
      // val rtModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
      //
      // val rtTotalCorrect = trainData.map { point =>
      //   if (rtModel.predict(point.features) == point.label) 1 else 0
      // }.sum
      // val rtAccuracy = rtTotalCorrect / trainData.count
      //
      // println("随机森林模型正确率:" + rtAccuracy)
      //
      // // Predict on the test set (every column is a feature; there is no label column).
      // val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
      // // Split each line on commas, producing an RDD of String arrays.
      // val testRecords = testRawData.map(line => line.split(","))
      //
      // val testData = testRecords.map { r =>
      //   val features = r.map(d => d.toDouble)
      //   Vectors.dense(features)
      // }
      // val predictions = rtModel.predict(testData).map(p => p.toInt)
      // // Save the predictions; use a directory distinct from the Naive Bayes run, since
      // // saveAsTextFile fails if the output directory already exists.
      // predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict3")

    } finally {
      sc.stop()
    }
  }

}

 

Spark学习笔记——手写数字识别

标签:rate   svm   create   line   str   for   bsp   cal   ice   

原文地址:http://www.cnblogs.com/tonglin0325/p/6906524.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!