码迷,mamicode.com
首页 > 编程语言 > 详细

【EM】C++代码实现

时间:2014-09-19 17:09:15      阅读:929      评论:0      收藏:0      [点我收藏+]

标签:style   blog   color   io   os   ar   for   数据   div   

看了原理和比人的代码后,终于自己写了一个EM的实现。

我从网上找了一些身高性别的数据,用EM算法通过身高信息来识别性别。

实现的效果还行,正确率有84% (初始数据 男生170 女生160 方差都是10)

                                 79%  (初始数据 男生165 女生150 方差都是10)

正确率与初始值有关。

/*
试图用EM算法来根据输入的身高来区分性别
*/

#include<iostream>
#include<fstream>
#include<algorithm>
#include<vector>
using namespace std;

#define PI 3.14159
#define max(x,y) (x > y ? x : y)

typedef struct FLOAT2
{
    float f1;
    float f2;
}FLOAT2;
typedef struct Gaussian
{
    float mean;
    float var;
}Gaussian;

typedef struct EMData
{
    char sex;
    float fHeight;
}EMData;

//获取身高性别数据
int getdata(vector<EMData> &Data)
{
    ifstream fin;
    fin.open("data.txt");
    if(!fin)
    {
        cout<<"error: can‘t open the file."<<endl;
        return -1;
    }

    while(!fin.eof())
    {
        char c[10];
        float height;
        fin >> c >> height;
        EMData data;
        data.sex = c[0];
        data.fHeight = height;
        Data.push_back(data);
    }

    return 0;
}

//根据身高数据区分性别, 返回正确率
float predict(vector<EMData> Data)
{
    //设符合正态分布
    Gaussian sex[2];
    float a[2]; //男女生所占百分比
    float t = 1;
    float tlimit = 0.000001; //收敛条件

    //赋初值 下标0表示男生 1表示女生
    sex[0].mean = 180.0;
    sex[0].var = 10.0;
    sex[1].mean = 150.0;
    sex[1].var = 10.0;
    a[0] = 0.5;
    a[1] = 0.5;

    while(t > tlimit)
    {
        Gaussian sex_old[2];
        float a_old[2];
        sex_old[0] = sex[0];
        sex_old[1] = sex[1];
        a_old[0] = a[0];
        a_old[1] = a[1];

        //计算每个样本分别被两个模型抽中的概率
        vector<FLOAT2> px;
    
        vector<EMData>::iterator it;
        for(it = Data.begin(); it < Data.end(); it++)
        {
            FLOAT2 p;
            p.f1 = 1/(sqrt(2 * PI * sex[0].var)) * exp(-(it->fHeight - sex[0].mean) * (it->fHeight - sex[0].mean) / (2 * sex[0].var));
            p.f2 = 1/(sqrt(2 * PI * sex[1].var)) * exp(-(it->fHeight - sex[1].mean) * (it->fHeight - sex[1].mean) / (2 * sex[1].var));
            px.push_back(p);
        }

        //E步
        //计算每个样本属于男生或女生的概率
        vector<FLOAT2>::iterator it2;
        for(it2 = px.begin(); it2 < px.end(); it2++)
        {
            float sum = 0.0;
            (*it2).f1 *= a[0];
            sum += (*it2).f1;
            (*it2).f2 *= a[1];
            sum += (*it2).f2;

            (*it2).f1 = (*it2).f1/sum;
            (*it2).f2 = (*it2).f2/sum;
        }

        //M步
        float sum_male = 0, sum_female = 0;
        float sum_mean_male = 0, sum_mean_female = 0;
        for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
        {
            sum_male += (*it2).f1;
            sum_female += (*it2).f2;
            sum_mean_male += (*it2).f1 * (it->fHeight);
            sum_mean_female += (*it2).f2 * (it->fHeight);
        }
        //更新a
        a[0] = sum_male/(sum_male + sum_female);
        a[1] = sum_female/(sum_male + sum_female);

        //更新均值
        sex[0].mean = sum_mean_male/ sum_male;
        sex[1].mean = sum_mean_female/ sum_female;

        //更新方差
        float sum_var_male = 0, sum_var_female = 0;
        for(it2 = px.begin(), it = Data.begin(); it2 < px.end(); it2++, it++)
        {
            sum_var_male += (*it2).f1 * ((it->fHeight) - sex[0].mean) * ((it->fHeight) - sex[0].mean);
            sum_var_female += (*it2).f2 * ((it->fHeight) - sex[1].mean) * ((it->fHeight) - sex[1].mean);
        }
        sex[0].var = sum_var_male / sum_male;
        sex[1].var = sum_var_female / sum_female;

        //计算变化率
        t = max((a[0] - a_old[0])/a_old[0], (a[1] - a_old[1])/a_old[1]);
        t = max(t, (sex[0].mean - sex_old[0].mean)/sex_old[0].mean);
        t = max(t, (sex[1].mean - sex_old[1].mean)/sex_old[1].mean);
        t = max(t, (sex[0].var - sex_old[0].var)/sex_old[0].var);
        t = max(t, (sex[1].var - sex_old[1].var)/sex_old[1].var);
    }

    //计算正确率
    int correct_num = 0;
    float correct_rate = 0;
    vector<EMData>::iterator it;
    for(it = Data.begin(); it < Data.end(); it++)
    {
        float p[2];
        char csex;
        for(int i = 0; i < 2; i++)
        {
            p[i] = 1/(sqrt(2 * PI * sex[i].var)) * exp(-(it->fHeight - sex[i].mean) * (it->fHeight - sex[i].mean) / (2 * sex[i].var));
        }

        csex = (p[0] > p[1]) ? m : f;
        if(csex == it->sex)
            correct_num++;
    }

    correct_rate = (float)correct_num / Data.size();
    return correct_rate;
}

int main()
{
    vector<EMData> Data;
    getdata(Data);
    float correct_rate = predict(Data);
    cout << "correct rate = "<< correct_rate << endl;
    return 0;
}

 

数据:data.txt内容

male    164
female    156
male    168
female    160
female    162
male    187
female    162
male    167
female    160.5
female    160
female    158
female    164
female    165
male    174
female    166
female    158
male     162
male    175
male    170
female    161
female    169
female    161
female    160
female    167
male    176
male    169
male    178
male    165
female    155
male    183
male    171
male    179
female    154
male    172
female    172
male    173
male    172
male    175
male    160
male    160
male    160
male    175
male    163
male    181
male    172
male    175
male    175
male    167
male    172
male    169
male    172
male    175
male    172
male    170
male    158
male    167
male    164
male    176
male    182
male    173
male    176
male    163
male    166
male    162
male    169
male    163
male    163
male    176
male    169
male    173
male    163
male    167
male    176
male    168
male    167
male    170
female    155
female    157
female    165
female    156
female    155
female    156
female    160
female    158
female    162
female    162
female    155
female    163
female    160
female    162
female    165
female    159
female    147
female    163
female    157
female    160
female    162
female    158
female    155
female    165
female    161
female    159
female    163
female    158
female    155
female    162
female    157
female    159
female    152
female    156
female    165
female    154
female    156
female    162

 

【EM】C++代码实现

标签:style   blog   color   io   os   ar   for   数据   div   

原文地址:http://www.cnblogs.com/dplearning/p/3981578.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!