
Spark MLlib: logistic regression source code analysis


I have been studying machine learning lately, using Spark as the tool. This article walks through the logistic regression (and linear regression) source code in MLlib as of the latest release at the time of writing, Spark 1.6.0. For the theory behind it, see: http://www.cnblogs.com/ljy2013/p/5129610.html

Let's dissect the source code step by step, starting from my demo:

package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{ LogisticRegressionWithLBFGS, LogisticRegressionModel }
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.SparkConf

object MyLogisticRegression {
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("Simple Application").setMaster("local[*]")
    val sc = new SparkContext(conf)

    // Load training data in LIBSVM format: <label> <index1>:<value1> <index2>:<value2> ...
    // Feature indices start at 1.
    val data = MLUtils.loadLibSVMFile(sc, "D:\\MyFile\\wine.txt")

    // Split data into training (60%) and test (40%).
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)

    // Run the training algorithm to build the model.
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(10) // set the number of classes
      .run(training)

    // Compute raw scores on the test set.
    val predictionAndLabels = test.map {
      case LabeledPoint(label, features) =>
        val prediction = model.predict(features)
        (prediction, label)
    }

    // Get evaluation metrics.
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val precision = metrics.precision
    println("Precision = " + precision)

    // Save and load the model.
    model.save(sc, "myModelPath")
    val sameModel = LogisticRegressionModel.load(sc, "myModelPath")
  }
}
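For reference, each line of the LIBSVM input pairs a label with sparse, 1-based index:value features. A hypothetical line for the wine data (the feature values here are invented for illustration, not taken from the actual file) would look like:

    1 1:14.23 2:1.71 3:2.43 4:15.6 5:127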

The demo shows that LogisticRegression is trained with L-BFGS, an unconstrained optimization algorithm used here to solve for the logistic regression parameters (the weights). If you are not familiar with it, see: http://www.cnblogs.com/ljy2013/p/5129610.html
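Incidentally, the same optimizer can also be driven directly through MLlib's LBFGS object, which is what LogisticRegressionWithLBFGS wraps internally. A minimal sketch, assuming binary 0/1 labels; the hyperparameter values are purely illustrative:

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

def trainWithLBFGS(training: RDD[LabeledPoint], numFeatures: Int): (Vector, Array[Double]) = {
  // LBFGS.runLBFGS consumes (label, features) pairs rather than LabeledPoint.
  val data = training.map(lp => (lp.label, lp.features))
  val initialWeights = Vectors.dense(new Array[Double](numFeatures))
  LBFGS.runLBFGS(
    data,
    new LogisticGradient(),  // gradient of the logistic loss
    new SquaredL2Updater(),  // L2 regularization
    10,                      // numCorrections: history size of the L-BFGS approximation
    1e-4,                    // convergenceTol
    100,                     // maxNumIterations
    0.1,                     // regParam (illustrative)
    initialWeights)          // returns (weights, loss at each iteration)
}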

I sketched the class inheritance hierarchy (figure omitted here; in short, LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm, and training produces a LogisticRegressionModel).

There are two main phases: training and prediction.

1. The training phase

The main program kicks off training by calling:

    // Run the training algorithm to build the model.
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(10) // set the number of classes
      .run(training)

After setting the number of classes, run is called on the LogisticRegressionWithLBFGS instance. The class itself does not define run; it inherits it from GeneralizedLinearAlgorithm, and that inherited method is where the whole training process happens. Let's take a look at it:

  def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {

    if (numFeatures < 0) {
      numFeatures = input.map(_.features.size).first()
    }
    // Training iterates over the data many times, so the input should be cached in memory.
    if (input.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data is not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }

    // Check the data properties before running the optimizer
    if (validateData && !validators.forall(func => func(input))) {
      throw new SparkException("Input validation failed.")
    }

    /**
     * Scaling columns to unit variance as a heuristic to reduce the condition number:
     *
     * During the optimization process, the convergence (rate) depends on the condition number of
     * the training dataset. Scaling the variables often reduces this condition number
     * heuristically, thus improving the convergence rate. Without reducing the condition number,
     * some training datasets mixing the columns with different scales may not be able to converge.
     *
     * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return
     * the weights in the original scale.
     * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
     *
     * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing
     * the variance of each column (without subtracting the mean), and train the model in the
     * scaled space. Then we transform the coefficients from the scaled space to the original scale
     * as GLMNET and LIBSVM do.
     *
     * Currently, it's only enabled in LogisticRegressionWithLBFGS
     */
    // Fit a standardizer on the features.
    val scaler = if (useFeatureScaling) {
      new StandardScaler(withStd = true, withMean = false).fit(input.map(_.features))
    } else {
      null
    }

    // Prepend an extra variable consisting of all 1.0's for the intercept.
    // TODO: Apply feature scaling to the weight vector instead of input data.
    val data =
      if (addIntercept) {
        if (useFeatureScaling) {
          input.map(lp => (lp.label, appendBias(scaler.transform(lp.features)))).cache()
        } else {
          input.map(lp => (lp.label, appendBias(lp.features))).cache()
        }
      } else {
        if (useFeatureScaling) {
          input.map(lp => (lp.label, scaler.transform(lp.features))).cache()
        } else {
          input.map(lp => (lp.label, lp.features))
        }
      }

    /**
     * TODO: For better convergence, in logistic regression, the intercepts should be computed
     * from the prior probability distribution of the outcomes; for linear regression,
     * the intercept should be set as the average of response.
     */
    val initialWeightsWithIntercept = if (addIntercept && numOfLinearPredictor == 1) {
      appendBias(initialWeights)
    } else {
      /** If `numOfLinearPredictor > 1`, initialWeights already contains intercepts. */
      initialWeights
    }

    // Run the optimizer on the weights; it returns the optimized weights,
    // i.e. the final model parameters.
    val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)

    val intercept = if (addIntercept && numOfLinearPredictor == 1) {
      weightsWithIntercept(weightsWithIntercept.size - 1)
    } else {
      0.0
    }

    var weights = if (addIntercept && numOfLinearPredictor == 1) {
      Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
    } else {
      weightsWithIntercept
    }

    /**
     * The weights and intercept are trained in the scaled space; we're converting them back to
     * the original scale.
     *
     * Math shows that if we only perform standardization without subtracting means, the intercept
     * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i
     * is the coefficient in the original space, and v_i is the variance of the column i.
     */
    if (useFeatureScaling) {
      if (numOfLinearPredictor == 1) {
        weights = scaler.transform(weights)
      } else {
        /**
         * For `numOfLinearPredictor > 1`, we have to transform the weights back to the original
         * scale for each set of linear predictor. Note that the intercepts have to be explicitly
         * excluded when `addIntercept == true` since the intercepts are part of weights now.
         */
        var i = 0
        val n = weights.size / numOfLinearPredictor
        val weightsArray = weights.toArray
        while (i < numOfLinearPredictor) {
          val start = i * n
          val end = (i + 1) * n - { if (addIntercept) 1 else 0 }

          val partialWeightsArray = scaler.transform(
            Vectors.dense(weightsArray.slice(start, end))).toArray

          System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.size)
          i += 1
        }
        weights = Vectors.dense(weightsArray)
      }
    }

    // Warn at the end of the run as well, for increased visibility.
    if (input.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data was not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }

    // Unpersist cached data
    if (data.getStorageLevel != StorageLevel.NONE) {
      data.unpersist(false)
    }

    createModel(weights, intercept)
  }

The first step in this method is to standardize the training data.
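To see the scaling step in isolation, here is a small self-contained sketch using the same StandardScaler configuration as run (withStd = true, withMean = false), reusing the sc from the demo; the feature vectors are made up for illustration:

import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

// Toy feature vectors with very different column scales (hypothetical values).
val features = sc.parallelize(Seq(
  Vectors.dense(1.0, 200.0),
  Vectors.dense(2.0, 400.0),
  Vectors.dense(3.0, 600.0)))

// Divide each column by its standard deviation without centering it,
// exactly the configuration run() uses when useFeatureScaling is on.
val scaler = new StandardScaler(withStd = true, withMean = false).fit(features)
val scaled = features.map(v => scaler.transform(v))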

The second step solves for the optimal weights via the optimizer. Note how it is done: val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) relies on polymorphism. Here optimizer is an abstract member of GeneralizedLinearAlgorithm, declared as follows:

 def optimizer: Optimizer

The subclass LogisticRegressionWithLBFGS overrides it (this design lets each algorithm bring its own optimizer for computing the optimal weights, while GeneralizedLinearAlgorithm stays generic across learning algorithms):

 override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
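This is the classic template-method pattern: the base class fixes the training skeleton and defers one step to its subclasses. A minimal standalone sketch of the idea (all names here are invented for illustration, not taken from the Spark source):

// The base class owns the generic training flow...
abstract class Learner {
  def optimize(data: Seq[Double], init: Double): Double // ...but defers optimization

  def run(data: Seq[Double]): Double = {
    val init = 0.0
    optimize(data, init) // dispatches to whatever the subclass provides
  }
}

// A concrete learner supplies its own optimization rule.
class GradientLearner extends Learner {
  override def optimize(data: Seq[Double], init: Double): Double =
    init - 0.01 * data.sum // stand-in for a real update rule
}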

Now the third step: creating the model. In GeneralizedLinearAlgorithm's run method, model creation is a single line: createModel(weights, intercept). But there is a deliberate design behind it. Like optimizer above, createModel is resolved polymorphically. GeneralizedLinearAlgorithm declares it as an abstract method:

 protected def createModel(weights: Vector, intercept: Double): M

The subclass LogisticRegressionWithLBFGS implements it:

  override protected def createModel(weights: Vector, intercept: Double) = {
    if (numOfLinearPredictor == 1) {
      // binary model
      new LogisticRegressionModel(weights, intercept)
    } else {
      // multiclass model
      new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
    }
  }

So the call actually dispatches to LogisticRegressionWithLBFGS's createModel method.
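As for the prediction phase: in the binary case, LogisticRegressionModel scores a point with the sigmoid of the linear predictor and compares it against a threshold (0.5 by default). A simplified sketch of that logic, not the Spark source itself:

import org.apache.spark.mllib.linalg.Vector

// Binary logistic prediction: sigmoid of (w . x + b), thresholded at 0.5.
def predictBinary(weights: Vector, intercept: Double, features: Vector): Double = {
  val margin = weights.toArray.zip(features.toArray)
    .map { case (w, x) => w * x }
    .sum + intercept
  val score = 1.0 / (1.0 + math.exp(-margin))
  if (score > 0.5) 1.0 else 0.0
}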

 


Original article: http://www.cnblogs.com/ljy2013/p/5135192.html
