码迷,mamicode.com
首页 > 编程语言 > 详细

线性回归之梯度下降算法

时间:2019-11-07 23:24:10      阅读:112      评论:0      收藏:0      [点我收藏+]

标签:不同   toc   样本   变化   多元线性回归   copy   线性回归   i++   伪代码   

线性回归之梯度下降法

1.梯度的概念

梯度是一个向量,对于一个多元函数\(f\)而言,\(f\)在点\(P(x,y)\)的梯度是\(f\)在点\(P\)处增大最快的方向,即以f在P上的偏导数为分量的向量。以二元函数\(f(x,y)\)为例,向量\(\{\frac{\partial f}{\partial x},\frac{\partial f}{\partial y}\}|_{(x_0,y_0)}=f_x(x_0,y_0)\overrightarrow i+f_y(x_0,y_0)\overrightarrow j\)就是函数\(f(x,y)\)在点\(P(x_0,y_0)\)处的梯度,记作\(gradf(x,y)\)或者\(\nabla f(x,y)\)

2.梯度下降法

对于梯度下降,我们可以形象地理解为一个人下山的过程。假设现在有一个人在山上,现在他想要走下山,但是他不知道山底在哪个方向,怎么办呢?显然我们可以想到的是,一定要沿着山高度下降的地方走,不然就不是下山而是上山了。山高度下降的方向有很多,选哪个方向呢?这个人比较有冒险精神,他选择最陡峭的方向,即山高度下降最快的方向。现在确定了方向,就要开始下山了。又有一个问题来了,在下山的过程中,最开始选定的方向并不总是高度下降最快的地方。这个人比较聪明,他每次都选定一段距离,每走一段距离之后,就重新确定当前所在位置的高度下降最快的地方。这样,这个人每次下山的方向都可以近似看作是每个距离段内高度下降最快的地方。现在我们将这个思想引入线性回归,在线性回归中,我们要找到参数矩阵\(\theta\)使得损失函数\(J(\theta)\)最小。如果把损失函数\(J(\theta)\)看作是这座山,山底不就是损失函数最小的地方吗,那我们求解参数矩阵\(\theta\)的过程,就是人走到山底的过程。

技术图片

如图所示,这是一元线性回归(即假设函数\(h_\theta(x) = \theta_0+\theta_1x\))中的损失函数图像,一开始我们选定一个起始点(通常是\((\theta_0=0,\theta_1=0)\)),然后沿着这个起始点开始,沿着这一点处损失函数下降最快的方向(即该点的梯度负方向)走一小步,走完一步之后,到达第二个点,然后我们又沿着第二个点的梯度负方向走一小步,到达第三个点,以此类推,直到我们到底局部最低点。为什么是局部最低点呢?因为我们到达的这个点的梯度为0向量(通常是和0向量相差在某一个可接受的范围内),这说明这个点是损失函数的极小值点,并不一定是最小值点。

技术图片

从梯度下降法的思想,我们可以看到,最后得到的局部最低点与我们选定的起始点有关。通常情况下,如果起始点不同,最后得到的局部最低点也会不一样。

3.梯度下降算法描述

现在对于梯度下降法,有了一个直观形象的理解了。接下来,我们看一下梯度下降算法。首先,我们给在下山的例子中每一段路的距离取名叫学习率(Learning Rate,也称步长,用\(\alpha\)表示),把一次下山走一段距离叫做一次迭代。算法详细过程:

  1. 确定定参数的初始值,计算损失函数的偏导数

  2. 将参数代入偏导数计算出梯度。若梯度为0,结束;否则转到3

  3. 用步长乘以梯度,并对参数进行更新
  4. 重复2-3

对于多元线性回归来说,拟合函数为:

\[h_\theta(x) = \sum_\limits{i=0}^n\theta_ix_i =\theta_0+ \theta_1x_1 + \cdots+\theta_nx_n \tag{3.1}\]

损失函数为:

\[J(\theta)=\frac{1}{2m}\sum_\limits{i=0}^m(y^{(i)}-h_\theta(x^{(i)}))^2 \tag{3.2}\]

损失函数的偏导数为:

\[\frac{\partial J(\theta)}{\theta_i} = \frac{1}{m}\sum_\limits{j=1}^m(h_\theta(x^{(j)})-y^{(j)})x_i^{(j)}=\frac{1}{m}\sum_\limits{j=1}^m(\sum_\limits{i=0}^n\theta_ix_i^{(j)}-y^{(j)})x_i^{(j)}\quad (i=0,1,\dots,n) \tag{3.3}\]

每次更新参数的操作为:

\[\theta_i = \theta_i-\alpha\frac{\partial J(\theta)}{\theta_i} = \theta_i-\alpha\frac{1}{m}\sum_\limits{j=1}^m(h_\theta(x^{(j)})-y^{(j)})x_i^{(j)}\quad (i=0,1,\dots,n)\tag{3.4}\]

注意,更新参数时必须同步更新所有参数,不能先更新\(\theta_0\)再更新\(\theta_1\),如果用Java伪代码就是:

double []temp = new double[n+1]; //因为参数??从??0到????,所以一共n+1,temp数组表示每次下降的高度
for(int i = 0; i < temp.length();i++){  
  double sum = 0;
  for(int j = 0; j < m; j++){
    double hx = 0;
    for(int k = 0; k <= n; k++){
        hx += theta[k]*x[j][k];
    }
    sum += (hx - y[j])*x[j][i];
  }
  temp[i] = alpha/m*sum;    
}
for(int i = 0; i < n; i++)
  theta[i] = theta[i] -temp[i];

对于这种需要同步更新的,最好的方法是采用矩阵运算,Java代码是(采用commons-math3库):

theta = (theta.substract(alpha/m*(X.transpose().multiply(X.multiply(theta)-y))).copy;
//theta为(n+1)*1维,X为m*(n+1)维,y为m*1维

数学推导如下:

\[h_\theta(x) = \theta^Tx,J(\theta) = \frac{1}{2m}(X\theta-Y)^T(X\theta-Y),\frac{\partial J(\theta)}{\partial\theta}=\frac{1}{m}X^T(X\theta-Y),\theta = \theta - \alpha\frac{1}{m} X^T(X\theta-Y)\]

其中\(\theta\)为(n+1)*1维,\(X\)为m*(n+1)维,\(Y\)为m*1维,\(x\)为(n+1)*1维

4.特征缩放

在梯度下降中,如果样本的某些维度取值范围并不合理,比如房价预测系统中将房价的单位采用万亿元/平方米,这样每一个样本值的房价会非常小,以至于在图形表示时,几乎是一条水平线,不同的拟合直线所计算出的损失函数之间的差值非常小。这样采用直线取进行拟合时,如果我们在梯度下降法中循环结束的标志是梯度值为0,理论上这样我们仍然可以计算出准确的参数值。但是通常情况下,我们采用的标志是梯度与0相差在某一个可接受的范围内,如果我们选定这个范围是0.0001,0.0001乘以万亿,其结果是一亿,这么大的误差显然是不能接受的。所以,我们有必要进行特征缩放,通常采用标准化的方式来进行特征缩放,即\(x_i = \frac{x_i-\overline x_i}{\sigma_i}\)。特征缩放之后,还可以加快我们的收敛速度。

5.其他梯度下降算法

在上文中,我们介绍了梯度下降法的思想和算法步骤,并给出了数学证明。实际上,文中介绍的是批量梯度下降法(Batch Gradient Descent, BSD)。梯度下降法,还有随机梯度下降法(Stochastic Gradient Descent, SGD)和小批量梯度下降法(Mini-batch Gradient Descent, MBSD)。它们之间的区别仅在于更新参数\(\theta\)的方式不同,即采用全体样本、随机样本或部分样本。

6.学习率的选择

在梯度下降算法中,迭代步长(即学习率)的选择非常重要。如果步长太大,最后可能不收敛,即出现振荡。如果步长太小,那么收敛速度太慢,我们需要很多次迭代来到达局部最优解。下图是对于某一房价预测系统,选择不同学习率时,损失函数随迭代次数的变化:

技术图片

可以看到在第一张图片中,当\(\alpha=1.6\)时,迭代次数大于40之后,损失函数明显越来越大。

技术图片

在第二张图片中,当\(\alpha = 0.00001\)时,损失函数每一次迭代减少的非常小,收敛速度很慢,我们可能需要很多次迭代才能得到局部最优解。

有时候,在梯度下降算法中,学习率并不总是固定的,有时候也会依据梯度来改变学习率。

参考链接:

梯度下降小结

梯度下降法小结

线性回归之梯度下降算法

标签:不同   toc   样本   变化   多元线性回归   copy   线性回归   i++   伪代码   

原文地址:https://www.cnblogs.com/liyier/p/11816925.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!