码迷,mamicode.com
首页 > 其他好文 > 详细

梯度上升法求解Logistic回归

时间:2014-06-01 04:34:43      阅读:410      评论:0      收藏:0      [点我收藏+]

标签:c   style   class   blog   code   tar   

回顾上次内容:http://blog.csdn.net/acdreamers/article/details/27365941

 

经过上次对Logistic回归理论的学习,我们已经推导出取对数后的似然函数为

 

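$$\ell(\theta)=\sum_{i=1}^{m}\Big[y^{(i)}\log h_\theta\big(x^{(i)}\big)+\big(1-y^{(i)}\big)\log\Big(1-h_\theta\big(x^{(i)}\big)\Big)\Big]=\sum_{i=1}^{m}\Big[y^{(i)}\,\theta^{T}x^{(i)}-\log\Big(1+e^{\theta^{T}x^{(i)}}\Big)\Big]$$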

 

现在我们的目的是求一个向量 $\theta$，使得 $\ell(\theta)$ 最大。其中

 

$$h_\theta(x)=g\big(\theta^{T}x\big)=\frac{1}{1+e^{-\theta^{T}x}}=\frac{e^{\theta^{T}x}}{1+e^{\theta^{T}x}}$$

 

$$\theta^{T}x=\sum_{j=0}^{n}\theta_j x_j=\theta_0+\theta_1x_1+\cdots+\theta_nx_n,\qquad x_0=1$$

 

对这个似然函数求偏导后得到

 

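$$\frac{\partial \ell(\theta)}{\partial\theta_j}=\sum_{i=1}^{m}\Big(y^{(i)}-h_\theta\big(x^{(i)}\big)\Big)x_j^{(i)}$$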

 

根据梯度上升算法有

 

$$\theta_j:=\theta_j+\alpha\,\frac{\partial \ell(\theta)}{\partial\theta_j}\qquad(\alpha\ \text{为步长，即学习率})$$

 

进一步得到

 

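$$\theta_j:=\theta_j+\alpha\sum_{i=1}^{m}\Big(y^{(i)}-h_\theta\big(x^{(i)}\big)\Big)x_j^{(i)}\qquad(j=0,1,\dots,n)$$
注意：所有 $\theta_j$ 都应使用同一个（更新前的）$\theta$ 计算 $h_\theta$，然后再同时更新。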

 

我们可以把向量 $\theta$ 初始化为 0 或随机值，然后反复迭代，直到相邻两次迭代的似然函数值之差小于指定的精度为止。

 

现在就来用C++一步一步实现Logistic回归,我们对文章末尾列出的数据进行训练。

 

首先，我们要读取文本。在训练数据中，每一行代表一组训练数据，每一组有 7 个数字：第 1 个数字是 ID，可以忽略；第 2~6 个数字是这组数据的 5 个特征输入；第 7 个数字是输出，取值为 0 或 1。数字之间用一个空格隔开。程序会在这 5 个特征前补上常数 1，作为偏置项 $x_0$。

 

首先我们来研究如何一行一行读取文本,在C++中,读取文本的一行用getline()函数。

 

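C++ 中的 getline（成员函数 `istream::getline(char* s, n)` 或自由函数 `std::getline(istream&, string&)`）读取文本的一行，并丢弃行尾的换行符。它返回的是流对象本身，而不是读取的字节数。放在 while 条件中时，读取成功则为真；到达文件末尾或读取出错时为假。（"返回读取的字节数、失败返回 -1" 描述的是 POSIX 的 C 函数 getline(3)。）用法如下: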

#include <iostream>
#include <string.h>
#include <fstream>
#include <string>
#include <stdio.h>

using namespace std;

int main()
{
    // Demonstrates reading a text file line by line and parsing three
    // space-separated integers from each line.
    string filename = "data.in";
    ifstream file(filename.c_str());
    if(!file.is_open())
    {
        cerr<<"cannot open "<<filename<<endl;
        return 1;
    }
    string line;
    // std::getline returns the stream: the loop ends at EOF or on a read
    // error. Reading into a std::string avoids the fixed-size buffer limit.
    while(getline(file,line))
    {
        int x,y,z;
        // sscanf returns the number of fields converted; skip blank or
        // malformed lines instead of printing uninitialized values.
        if(sscanf(line.c_str(),"%d %d %d",&x,&y,&z) != 3)
            continue;
        cout<<x<<" "<<y<<" "<<z<<endl;
    }
    return 0;
}


 拿到每一行后，就可以用 sscanf 把其中的各个数字解析出来，作为模型的输入。

 

代码:

#include <iostream>
#include <string.h>
#include <fstream>
#include <stdio.h>
#include <math.h>
#include <vector>

using namespace std;

// One sample: feature vector x (loadData puts a constant 1 first as the bias
// input) and class label y (0 or 1).
struct Data
{
    vector<int> x;
    int y;
    // Empty sample; y is zeroed so it is never read uninitialized.
    Data() : y(0) {}
    // Features are taken by const reference to avoid an extra by-value copy.
    Data(const vector<int>& x,int y) : x(x), y(y) {}
};

// Model weights: one per feature, including the bias weight.
vector<double> w;
// Training set, filled by Init().
// NOTE: with `using namespace std;` under C++17, std::data(...) is visible
// too, so unqualified uses of this name (outside a scope that shadows it)
// can be ambiguous; refer to it as ::data.
vector<Data> data;

// Reads samples from the text file at `path` and appends them to `data`.
// Each line must hold 7 integers: an ID (ignored), five features and a 0/1
// label. A constant 1 is prepended to the features as the bias input x0.
// Blank or malformed lines are skipped. If the file cannot be opened, a
// warning goes to stderr and `data` is left unchanged.
void loadData(vector<Data>& data,const string& path)
{
    ifstream file(path.c_str());
    if(!file.is_open())
    {
        cerr<<"loadData: cannot open "<<path<<endl;
        return;
    }
    char s[1024];
    // NOTE: istream::getline sets failbit (ending the loop) on a line longer
    // than 1023 characters.
    while(file.getline(s,1024))
    {
        int id,f1,f2,f3,f4,f5,label;
        // Require all seven fields; otherwise the locals would be read uninitialized.
        if(sscanf(s,"%d %d %d %d %d %d %d",&id,&f1,&f2,&f3,&f4,&f5,&label) != 7)
            continue;
        Data tmp;
        tmp.x.push_back(1);   // bias input x0
        tmp.x.push_back(f1);
        tmp.x.push_back(f2);
        tmp.x.push_back(f3);
        tmp.x.push_back(f4);
        tmp.x.push_back(f5);
        tmp.y = label;
        data.push_back(tmp);
    }
}

// Loads the training set from "traindata.txt" into the global ::data and sets
// the weights w to zero, one per feature (including the bias weight).
// If no samples were loaded, w is left empty and a warning is printed rather
// than indexing an empty vector.
void Init()
{
    w.clear();
    // Qualified as ::data: unqualified `data` is ambiguous with std::data in
    // C++17 because of `using namespace std;`.
    ::data.clear();
    loadData(::data,"traindata.txt");
    if(::data.empty())
    {
        cerr<<"Init: no training samples loaded"<<endl;
        return;
    }
    w.assign(::data[0].x.size(),0.0);
}

// Linear score of a sample: the dot product sum_k w[k] * data.x[k].
// The loop runs over w's length, so data.x must have at least w.size() entries.
double WX(const vector<double>& w,const Data& data)
{
    const vector<double>::size_type n = w.size();
    double sum = 0;
    for(vector<double>::size_type k = 0; k != n; ++k)
    {
        const double term = w[k] * data.x[k];
        sum += term;
    }
    return sum;
}

// Logistic function of the sample's linear score z = w.x:
//   h = 1 / (1 + exp(-z)) = exp(z) / (1 + exp(z)).
// The branch keeps exp() from overflowing. The single form exp(z)/(1+exp(z))
// becomes inf/inf = NaN once z exceeds about 709.
double sigmoid(const vector<double>& w,const Data& data)
{
    double z = WX(w,data);
    if(z >= 0)
        return 1.0 / (1.0 + exp(-z));   // exp(-z) <= 1, cannot overflow
    double e = exp(z);                  // z < 0, so e < 1
    return e / (1.0 + e);
}

double Lw(vector<double> w)
{
    double ans = 0;
    for(int i=0;i<data.size();i++)
    {
        double x = WX(w,data[i]);
        ans += data[i].y * x - log(1 + exp(x));
    }
    return ans;
}

void gradient(double alpha)
{
    for(int i=0;i<w.size();i++)
    {
        double tmp = 0;
        for(int j=0;j<data.size();j++)
            tmp += alpha * data[j].x[i] * (data[j].y - sigmoid(w,data[j]));
        w[i] += tmp;
    }
}

// Reports progress of iteration `cnt`: the previous objective value objLw,
// the change newLw - objLw, and then the current weights w on one line,
// followed by a blank line.
void display(int cnt,double objLw,double newLw)
{
    const double diff = newLw - objLw;
    cout<<"第"<<cnt<<"次迭代:  ojLw = "<<objLw<<"  两次迭代的目标差为: "<<diff<<endl;
    cout<<"参数w为: ";
    for(vector<double>::const_iterator it = w.begin(); it != w.end(); ++it)
        cout<<*it<<" ";
    cout<<endl<<endl;
}

// Fits the global weights w by batch gradient ascent. It stops when the
// log-likelihood changes by no more than `delta` between consecutive steps,
// or after `maxIter` steps.
// Every step is counted from 1 and reported via display(previous, new).
// The defaults reproduce the original settings (alpha = 0.1, delta = 0.0001).
// maxIter keeps a step size that makes the objective oscillate from looping forever.
void Logistic(double alpha = 0.1,double delta = 0.0001,int maxIter = 100000)
{
    double objLw = Lw(w);
    for(int cnt = 1; cnt <= maxIter; cnt++)
    {
        gradient(alpha);
        double newLw = Lw(w);
        display(cnt,objLw,newLw);
        // Written as !(x > delta) so that a NaN objective also stops the loop.
        if(!(fabs(newLw - objLw) > delta))
            return;
        objLw = newLw;
    }
    cerr<<"Logistic: no convergence after "<<maxIter<<" iterations"<<endl;
}

void Separator()
{
    vector<Data> data;
    loadData(data,"testdata.txt");
    cout<<"预测分类结果:"<<endl;
    for(int i=0;i<data.size();i++)
    {
        double p0 = 0;
        double p1 = 0;
        double x = WX(w,data[i]);
        p1 = exp(x) / (1 + exp(x));
        p0 = 1 - p1;
        cout<<"实例: ";
        for(int j=0;j<data[i].x.size();j++)
            cout<<data[i].x[j]<<" ";
        cout<<"所属类别为:";
        if(p1 >= p0) cout<<1<<endl;
        else cout<<0<<endl;
    }
}

// Program entry point:
//   1. Init() loads "traindata.txt" and zero-initializes the weights.
//   2. Logistic() fits them by gradient ascent.
//   3. Separator() prints predicted classes for "testdata.txt".
// The order matters: each step uses the global state set by the one before.
int main()
{
    Init();
    Logistic();
    Separator();
    return 0;
}

 

训练数据:traindata.txt

10009 1 0 0 1 0 1
10025 0 0 1 2 0 0
10038 1 0 0 1 1 0
10042 0 0 0 0 1 0
10049 0 0 1 0 0 0
10113 0 0 1 0 1 0
10131 0 0 1 2 1 0
10160 1 0 0 0 0 0
10164 0 0 1 0 1 0
10189 1 0 1 0 0 0
10215 0 0 1 0 1 0
10216 0 0 1 0 0 0
10235 0 0 1 0 1 0
10270 1 0 0 1 0 0
10282 1 0 0 0 1 0
10303 2 0 0 0 1 0
10346 1 0 0 2 1 0
10380 2 0 0 0 1 0
10429 2 0 1 0 0 0
10441 0 0 1 0 1 0
10443 0 0 1 2 0 0
10463 0 0 0 0 0 0
10475 0 0 1 0 1 0
10489 1 0 1 0 1 1
10518 0 0 1 2 1 0
10529 1 0 1 0 0 0
10545 0 0 1 0 0 0
10546 0 0 0 2 0 0
10575 1 0 0 0 1 0
10579 2 0 1 0 0 0
10581 2 0 1 1 1 0
10600 1 0 1 1 0 0
10627 1 0 1 2 0 0
10653 1 0 0 1 1 0
10664 0 0 0 0 1 0
10691 1 1 0 0 1 0
10692 1 0 1 2 1 0
10711 0 0 0 0 1 0
10714 0 0 1 0 0 0
10739 1 0 1 1 1 0
10750 1 0 1 0 1 0
10764 2 0 1 2 0 0
10770 0 0 1 2 1 0
10780 0 0 1 0 1 0
10784 2 0 1 0 1 0
10785 0 0 1 0 1 0
10788 1 0 0 0 0 0
10815 1 0 0 0 1 0
10816 0 0 0 0 1 0
10818 0 0 1 2 1 0
11095 0 1 1 0 0 0
11146 0 1 0 0 1 0
11206 2 1 0 0 0 0
11223 2 1 0 0 0 0
11236 1 1 0 2 0 0
11244 1 1 0 0 0 1
11245 0 1 0 0 0 0
11278 2 1 0 0 1 0
11322 0 1 0 0 1 0
11326 2 1 0 2 1 0
11329 2 1 0 2 1 0
11344 1 1 0 2 1 0
11358 0 1 0 0 0 1
11417 2 1 1 0 1 0
11421 2 1 0 1 1 0
11484 1 1 0 0 0 1
11499 2 1 0 0 0 0
11503 1 1 0 0 1 0
11527 1 1 0 0 0 0
11540 2 1 0 1 1 0
11580 1 1 0 0 1 0
11583 1 0 1 1 0 1
11592 2 1 0 1 1 0
11604 0 1 0 0 1 0
11625 1 0 1 0 0 0
20035 0 0 1 0 0 1
20053 1 0 0 0 0 0
20070 0 0 0 2 1 0
20074 1 0 1 2 0 1
20146 1 0 0 1 1 0
20149 2 0 1 2 1 0
20158 2 0 0 0 1 0
20185 1 0 0 1 1 0
20193 1 0 1 0 1 0
20194 0 0 1 0 0 0
20205 1 0 0 2 1 0
20206 2 0 1 1 1 0
20265 0 0 1 0 1 0
20311 0 0 0 0 1 0
20328 2 0 0 1 0 1
20353 0 0 1 0 0 0
20372 0 0 0 0 0 0
20405 1 0 1 1 1 1
20413 2 0 1 0 1 0
20427 0 0 0 0 0 0
20455 1 0 1 0 1 0
20462 0 0 0 0 1 0
20472 0 0 0 2 0 0
20485 0 0 0 0 0 0
20523 0 0 1 2 0 0
20539 0 0 1 0 1 0
20554 0 0 1 0 0 1
20565 0 0 0 2 1 0
20566 1 0 1 1 1 0
20567 1 0 0 1 1 0
20568 0 0 1 0 1 0
20569 1 0 0 0 0 0
20571 1 0 1 0 1 0
20581 2 0 0 0 1 0
20583 1 0 0 0 1 0
20585 2 0 0 1 1 0
20586 0 0 1 2 1 0
20591 1 0 1 2 0 0
20595 0 0 1 2 1 0
20597 1 0 0 0 0 0
20599 0 0 1 0 1 0
20607 0 0 0 1 1 0
20611 1 0 0 0 1 0
20612 2 0 0 1 1 0
20614 1 0 0 1 1 0
20615 1 0 1 0 0 0
21017 1 1 0 1 1 0
21058 2 1 0 0 1 0
21063 0 1 0 0 0 0
21084 1 1 0 1 0 1
21087 1 1 0 2 1 0
21098 0 1 0 0 0 0
21099 1 1 0 2 0 0
21113 0 1 0 0 1 0
21114 1 1 0 0 1 1
21116 1 1 0 2 1 0
21117 1 0 0 2 1 0
21138 2 1 1 1 1 0
21154 0 1 0 0 1 0
21165 0 1 0 0 1 0
21181 2 1 0 0 0 1
21183 1 1 0 2 1 0
21231 1 1 0 0 1 0
21234 1 1 1 0 0 0
21286 2 1 0 2 1 0
21352 2 1 1 1 0 0
21395 0 1 0 0 1 0
21417 1 1 0 2 1 0
21423 0 1 0 0 1 0
21426 1 1 0 1 1 0
21433 0 1 0 0 1 0
21435 0 1 0 0 0 0
21436 1 1 0 0 0 0
21439 1 1 0 2 1 0
21446 1 1 0 0 0 0
21448 0 1 1 2 0 0
21453 2 1 0 0 1 0
30042 2 0 1 0 0 1
30080 0 0 1 0 1 0
301003 1 0 1 0 0 0
301009 0 0 1 2 1 0
301017 0 0 1 0 0 0
30154 1 0 1 0 1 0
30176 0 0 1 0 1 0
30210 0 0 1 0 1 0
30239 1 0 1 0 1 0
30311 0 0 0 0 0 1
30382 0 0 1 2 1 0
30387 0 0 1 0 1 0
30415 0 0 1 0 1 0
30428 0 0 1 0 0 0
30479 0 0 1 0 0 1
30485 0 0 1 2 1 0
30493 2 0 1 2 1 0
30519 0 0 1 0 1 0
30532 0 0 1 0 1 0
30541 0 0 1 0 1 0
30567 1 0 0 0 0 0
30569 2 0 1 1 1 0
30578 0 0 1 0 0 1
30579 1 0 1 0 0 0
30596 1 0 1 1 1 0
30597 1 0 1 1 0 0
30618 0 0 1 0 0 0
30622 1 0 1 1 1 0
30627 1 0 1 2 0 0
30648 2 0 0 0 1 0
30655 0 0 1 0 0 1
30658 0 0 1 0 1 0
30667 0 0 1 0 1 0
30678 1 0 1 0 0 0
30701 0 0 1 0 0 0
30703 2 0 1 1 0 0
30710 0 0 1 2 0 0
30713 1 0 0 1 1 1
30716 0 0 0 0 1 0
30721 0 0 0 0 0 1
30723 0 0 1 0 1 0
30724 2 0 1 2 1 0
30733 1 0 0 1 0 0
30734 0 0 1 0 0 0
30736 2 0 0 1 1 1
30737 0 0 1 0 0 0
30740 0 0 1 0 1 0
30742 2 0 1 0 1 0
30743 0 0 1 0 1 0
30745 2 0 0 0 1 0
30754 1 0 1 0 1 0
30758 1 0 0 0 1 0
30764 0 0 1 0 0 1
30765 2 0 0 0 0 0
30769 2 0 0 1 1 0
30772 0 0 1 0 1 0
30774 0 0 0 0 1 0
30784 2 0 1 0 0 0
30786 1 0 1 0 1 0
30787 0 0 0 0 1 0
30789 1 0 1 0 1 0
30800 0 0 1 0 0 0
30801 1 0 1 0 1 0
30803 1 0 1 0 1 0
30806 1 0 1 0 1 0
30817 0 0 1 2 0 0
30819 2 0 1 0 1 1
30822 0 0 1 0 1 0
30823 0 0 1 2 1 0
30834 0 0 0 0 0 0
30836 0 0 1 0 1 0
30837 1 0 1 0 1 0
30840 0 0 1 0 1 0
30841 1 0 1 0 0 0
30844 0 0 1 0 1 0
30845 0 0 1 0 0 0
30847 1 0 1 0 0 0
30848 0 0 1 0 1 0
30850 0 0 1 0 1 0
30856 1 0 0 0 1 0
30858 0 0 1 0 0 0
30860 0 0 0 0 1 0
30862 1 0 1 1 1 0
30864 0 0 0 2 0 0
30867 0 0 1 0 1 0
30869 0 0 1 0 1 0
30887 0 0 1 0 1 0
30900 1 0 0 1 1 0
30913 2 0 0 0 1 0
30914 1 0 0 0 0 0
30922 2 0 0 2 1 0
30923 0 0 1 2 1 0
30927 1 0 1 0 0 1
30929 0 0 1 2 1 0
30933 0 0 1 2 1 0
30940 0 0 1 0 1 0
30943 1 0 1 2 1 0
30945 0 0 0 2 0 0
30951 1 0 0 0 0 0
30964 0 0 0 2 1 0
30969 0 0 1 0 1 0
30979 2 0 0 0 1 0
30980 1 0 0 0 0 0
30982 1 0 0 1 1 0
30990 1 0 1 1 1 0
30991 1 0 1 0 1 1
30999 0 0 1 0 1 0
31056 1 1 0 2 1 0
31068 1 1 0 1 0 0
31108 2 1 0 2 1 0
31168 1 1 1 0 0 0
31191 0 1 1 0 0 0
31229 0 1 1 0 0 1
31263 0 1 0 0 1 0
31281 1 1 1 0 0 0
31340 1 1 1 0 1 0
31375 0 1 0 0 1 0
31401 0 1 1 0 0 1
31480 1 1 1 1 1 0
31501 1 1 0 2 1 0
31514 0 1 0 2 0 0
31518 1 1 0 2 1 0
31532 0 0 1 2 1 0
31543 2 1 1 1 1 0
31588 0 1 0 0 1 0
31590 0 0 1 0 1 0
31591 2 1 0 1 1 0
31595 0 1 0 0 1 0
31596 1 1 0 0 0 0
31598 1 1 0 0 1 0
31599 0 1 0 0 0 0
31605 0 1 1 0 0 0
31612 2 1 0 0 1 0
31615 2 1 0 0 0 0
31628 1 1 0 0 1 0
31640 2 1 0 1 1 0


 

测试数据:testdata.txt

10009 1 0 0 1 0 1
10025 0 0 1 2 0 0
20035 0 0 1 0 0 1
20053 1 0 0 0 0 0
30627 1 0 1 2 0 0
30648 2 0 0 0 1 0


 

梯度上升法求解Logistic回归,布布扣,bubuko.com

梯度上升法求解Logistic回归

标签:c   style   class   blog   code   tar   

原文地址:http://blog.csdn.net/acdreamers/article/details/27688361

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!