码迷,mamicode.com
首页 > 其他好文 > 详细

线性回归

时间:2015-09-11 23:36:37      阅读:494      评论:0      收藏:0      [点我收藏+]

标签:

线性回归简介

线性回归是指利用线性回归方程中的最小平方函数对一个或多个自变量和因变量之间的关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。

案例简介

本案例中的数据是一组男孩年龄与身高的数据,我们将年龄作为自变量,身高作为因变量,二组数据分别从.dat文件中读取,最后拟合出一个线性关系式。

具体分析见http://openclassroom.stanford.edu/MainFolder/DocumentPage.php?course=MachineLearning&doc=exercises/ex2/ex2.html

我的解决代码

首先提供我的MATLAB代码:

% 这是一个线性回归案例代码

x=load(‘ex2Data\ex2x.dat‘);
y=load(‘ex2Data\ex2y.dat‘);
plot(x,y,‘o‘);
xlabel(‘年龄‘);
ylabel(‘身高‘);
hold on;

%%下面开始进行迭代
m=length(x);
x=[ones(m,1),x];
theta=[0,0];
alpha=0.07;
J_theta=1;
tmp=0;
count=0;

while abs(J_theta-tmp)>0.0000000000001
    count=count+1;
    h_theta=theta*(x‘);
    tmp=J_theta;
    J_theta=1/(2*m)*sum((h_theta-y‘).^2);
    theta(1)=theta(1)-alpha/m*sum((h_theta-y‘)*x(:,1));
    theta(2)=theta(2)-alpha/m*sum((h_theta-y‘)*x(:,2));
end
plot(x,theta(2)*x+theta(1),‘r‘,‘LineWidth‘,2);
fprintf(‘迭代:%d次\n‘,count);
fprintf(‘y=%f*x+%f\n‘,theta(2),theta(1));
fprintf(‘预测,年龄为18的男孩,身高为%f米\n‘,theta(2)*18+theta(1));


下面是我的Java代码:

package linearRegression;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;

import machinelearning.Matrix;

/**
 * 线性回归模型
 * @author zzw922cn
 * @version 2015-9-11
 */
public class Main {

	public static void main(String[] args) throws Exception {
		int size=50;
		Matrix x = new Matrix(size,2);
		Matrix y = new Matrix(size,1);
		Matrix theta = new Matrix(1,2);
		Matrix h_theta = new Matrix(1,size);
		double alpha= 0.07;
		double J_theta=1;
		double tmp=0;
		int count=0;
		
		//读取数据存进数组
		File xFile = new File("ex2x.dat");
		File yFile = new File("ex2y.dat");
		FileReader xFileReader = new FileReader(xFile);
		FileReader yFileReader = new FileReader(yFile);
		BufferedReader xBufferedReader = new BufferedReader(xFileReader);
		BufferedReader yBufferedReader = new BufferedReader(yFileReader);
		String xline=null;
		String yline=null;
		int index = 0;
		while((xline=xBufferedReader.readLine())!=null&&(yline=yBufferedReader.readLine())!=null) {
			x.setToSpecifiedValue(index, 0, 1);
			x.setToSpecifiedValue(index, 1, Float.parseFloat(xline));
			y.setToSpecifiedValue(index, 0, Float.parseFloat(yline));
			index++;
		}
		xBufferedReader.close();
		yBufferedReader.close();
		
		//开始运用梯度下降算法更新参数,直到参数收敛为止
		double theta0;
		double theta1;
		double tmp0;
		double tmp1; 
		while(Math.abs(J_theta-tmp)>0.0000000001) {
			count++;
			h_theta=theta.multiply(x.transpose());
			tmp=J_theta;
			J_theta = h_theta.minus(y.transpose()).dotBySelf().sum()*(1.0/(2*size));
			theta0 = theta.getElement(0, 0);
			theta1 = theta.getElement(0, 1);
			tmp0 = h_theta.minus(y.transpose()).multiply(x.getColoum(0)).sum()*alpha/size;
			tmp1 = h_theta.minus(y.transpose()).multiply(x.getColoum(1)).sum()*alpha/size;
			theta.setToSpecifiedValue(0, 0, (float) (theta0-tmp0));
			theta.setToSpecifiedValue(0, 1, (float) (theta1-tmp1));
		}
		System.out.println("迭代"+count+"次");
		System.out.println("拟合表达式为y="+theta.getElement(0, 1)+"*x+"+theta.getElement(0, 0));
		System.out.println("预测,年龄为18的男孩,身高为"+(theta.getElement(0, 1)*18+theta.getElement(0, 0))+"米");
	}
}

运行结果:

技术分享

线性回归

标签:

原文地址:http://my.oschina.net/zzw922cn/blog/505162

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!