码迷,mamicode.com
首页 > 编程语言 > 详细

多元线性回归----Java简单实现

时间:2014-10-30 13:08:29      阅读:348      评论:0      收藏:0      [点我收藏+]

标签:style   blog   http   io   color   os   ar   java   for   

学习Andrew N.g的机器学习课程之后的简单实现.

课程地址:https://class.coursera.org/ml-007

 不大会编辑公式,所以略去具体的推导,有疑惑的同学去看看Andrew 的课程吧,顺带一句,Andrew的课程实在是很赞。

如果还有疑问,feel free to contact me via emails or QQ.

 

LinearRegression.java

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;

public class LinearRegression {
    /*
     * 训练数据示例:
     *   x0        x1        x2        y 
        1.0       1.0       2.0       7.2 
        1.0       2.0       1.0       4.9 
        1.0       3.0       0.0       2.6 
        1.0       4.0       1.0       6.3 
        1.0       5.0      -1.0       1.0 
        1.0       6.0       0.0       4.7 
        1.0       7.0      -2.0      -0.6 
        注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。
        x0,x1,x2是“特征”,y是结果
        
        h(x) = theta0 * x0 + theta1* x1 + theta2 * x2
        
        theta0,theta1,theta2 是想要训练出来的参数
        
         此程序采用“梯度下降法”
        
     * 
     */

    private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
    private int row;//训练数据  行数
    private int column;//训练数据 列数
    
    private double [] theta;//参数theta
    
    private double alpha;//训练步长
    private int iteration;//迭代次数
    
    public LinearRegression(String fileName)
    {   
        int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数
        
        trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
        this.row=rowoffile;
        this.column=columnoffile+1;
        
        this.alpha = 0.001;//步长默认为0.001
        this.iteration=100000;//迭代次数默认为 100000
        
        theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
        initialize_theta();
        
        loadTrainDataFromFile(fileName,rowoffile,columnoffile);
    }
    public LinearRegression(String fileName,double alpha,int iteration)
    {   
        int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
        int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数
        
        trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
        this.row=rowoffile;
        this.column=columnoffile+1;
        
        this.alpha = alpha;
        this.iteration=iteration;
        
        theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
        initialize_theta();
        
        loadTrainDataFromFile(fileName,rowoffile,columnoffile);
    }
    
    
    private int getRowNumber(String fileName)
    {
        int count =0;
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            while ( reader.readLine() != null) 
                count++;
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return count;
        
    }
    
    private int getColumnNumber(String fileName)
    {
        int count =0;
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = reader.readLine();
            if(tempString!=null)
                count = tempString.split(" ").length;
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return count;
    }
    
    private void initialize_theta()//将theta各个参数全部初始化为1.0
    {
        for(int i=0;i<theta.length;i++)
            theta[i]=1.0;
    }
    
    public void trainTheta()
    {
        int iteration = this.iteration;
        while( (iteration--)>0 )
        {
                //对每个theta i 求 偏导数
            double [] partial_derivative = compute_partial_derivative();//偏导数
                //更新每个theta
            for(int i =0; i< theta.length;i++)
                theta[i]-= alpha * partial_derivative[i];
        }
    }
    
    private double [] compute_partial_derivative()
    {
        double [] partial_derivative = new double[theta.length];
        for(int j =0;j<theta.length;j++)//遍历,对每个theta求偏导数
        {
            partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta i 求 偏导
        }
        return partial_derivative;
    }
    private double compute_partial_derivative_for_theta(int j)
    {
        double sum=0.0;
        for(int i=0;i<row;i++)//遍历 每一行数据
        {
            sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);
        }
        return sum/row;
    }
    private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j)
    {
        double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是y
        double result = 0.0;
        
        for(int k=0;k< (oneRow.length-1);k++)
            result+=theta[k]*oneRow[k];
        result-=oneRow[oneRow.length-1];
        result*=oneRow[j];
        return result;
    }
    private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1)
    {
        return trainData[i];
    }
    
    
    private void loadTrainDataFromFile(String fileName,int row, int column)
    {   
        for(int i=0;i< row;i++)//trainData的第一列全部置为1.0(feature x0)
            trainData[i][0]=1.0;
        
        File file = new File(fileName);
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = null;
            int counter = 0;
            while ( (counter<row) && (tempString = reader.readLine()) != null) {
                String [] tempData = tempString.split(" ");
                for(int i=0;i<column;i++)
                    trainData[counter][i+1]=Double.parseDouble(tempData[i]);
                counter++;
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
    }
    
    public void printTrainData()
    {
        System.out.println("Train Data:\n");
        for(int i=0;i<column-1;i++)
            System.out.printf("%10s","x"+i+" ");
        System.out.printf("%10s","y"+" \n");
        for(int i=0;i<row;i++)
        {
            for(int j=0;j<column;j++)
            {
                System.out.printf("%10s",trainData[i][j]+" ");
            }
            System.out.println();
        }
        System.out.println();
    }
    
    public void printTheta()
    {
        for(double a:theta)
            System.out.print(a+" ");
    }

}

TestLinearRegression.java

public class TestLinearRegression {

    public static void main(String[] args) {
        // TODO Auto-generated method stub
         LinearRegression m = new LinearRegression("trainData",0.001,1000000);
         m.printTrainData();
         m.trainTheta();
         m.printTheta();
    }

}

trainData文件中是训练数据,默认最后一列是y,比如:

             1.0       2.0       7.2
             2.0       1.0       4.9
             3.0       0.0       2.6
             4.0       1.0       6.3
             5.0      -1.0       1.0
            6.0       0.0       4.7
            7.0      -2.0      -0.6

前两列是“feature”,最后一列,也就是第三列是y

 

Email: wuzimian2006@163.com

 QQ:    726590906    

多元线性回归----Java简单实现

标签:style   blog   http   io   color   os   ar   java   for   

原文地址:http://www.cnblogs.com/wzm-xu/p/4062266.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!