码迷,mamicode.com
首页 > 其他好文 > 详细

统计学习方法笔记2:感知机

时间:2015-04-29 23:04:34      阅读:186      评论:0      收藏:0      [点我收藏+]

标签:

1.感知机:二类分类的线性模型,输入为实例的特征向量,输出为某类别,取+1和-1.目的在求出将训练数据进行线性划分的分离超平面,导入基于误分类的损失函数,利用梯度下降法对损失函数进行极小化求得感知机模型。

2.感知机模型:

  技术分享,sign为符号函数,w为权值或权向量,b为偏置。

  其几何解释:技术分享对应一个越平面,w为法向量,b截距。

3.感知机学习策略

  1)数据集的线性可分性:

  数据集技术分享,存在一个超平面S将数据集正实例和负实例完全分布在平面两侧。

  2)策略:

  任意点到超平面的距离:技术分享技术分享

  总距离:技术分享,在不考虑1/||w||就得到了感知机学习损失函数。

  损失函数:技术分享,y={-1,1}

4.学习算法

  1)极小化问题的解:

  技术分享

  a.选取初值w。,b。

  b.在训练集中选取数据(x,y)

  c.如果y(w•x+b)≤0

  技术分享

  技术分享

  d.迭代,直至数据集中没有误分点,即L(w,b)=0.

  

  1 include <iostream>
  2 #include <vector>
  3 #include <algorithm>
  4 #define random(x) (rand() % (x))
  5 
  6 double dot_product(std::vector<double>& a, std::vector<double>& b){
  7     if (a.size() != b.size()) return 0;
  8     double res = 0;
  9     for (int i = 0; i < a.size(); ++i){
 10         res += a[i] * b[i];
 11     }
 12     return res;
 13 }
 14 
 15 class Preception{
 16 public:
 17     Preception(int iters = 100, int learnRate = 1, double initw = 0, double initb = 0){
 18         iterators = iters;
 19         w.push_back(initw);
 20         b = initb;
 21         step = learnRate;
 22 
 23     }
 24     ~Preception(){
 25         w.clear();
 26         b = 0;
 27     }
 28     bool train(std::vector<std::vector<double> >& train_x, std::vector<int>& train_y){
 29         if (train_x.size() != train_y.size()) return false;
 30         initWeight(train_x[0].size());
 31 
 32         for (int iter = 0; iter < iterators; ++iter){
 33             bool flag = true;
 34             for (int i = 0; i < train_x.size();){
 35                 if ((dot_product(w, train_x[i]) + b)*(double)train_y[i] <= 0){
 36                     update(train_x[i], train_y[i]);
 37                     flag = false;
 38                 }
 39                 else{
 40                     ++i;
 41                 }
 42             }
 43             if (flag) return true;
 44         }
 45         return false;
 46     }
 47 
 48     std::vector<int> predict(std::vector<std::vector<double> >& data_x){
 49         std::vector<int> ret;
 50         for (int i = 0; i < data_x.size(); ++i){
 51             ret.push_back(predict(data_x[i]));
 52 
 53         }
 54         return ret;
 55 
 56     }
 57 
 58     int predict(std::vector<double>& x){
 59         return dot_product(x, w) + b > 0 ? 1 : -1;
 60 
 61     }
 62     void printPreceptronModel(){
 63         std::cout << "原始形式感知机模型:f(x)=sign(";
 64         for (int i = 0; i < w.size(); ++i){
 65             if (i) std::cout << "+";
 66             if (w[i] != 1) std::cout << w[i];
 67             std::cout << "x" << i + 1;
 68         }
 69         if (b > 0) std::cout << "+";
 70         std::cout << b << ")" << std::endl;
 71     }
 72 private:
 73     void initWeight(int size){
 74         for (int i = 1; i < size; ++i){
 75             w.push_back(w[0]);
 76         }
 77     }
 78 
 79     void update(std::vector<double>& x, double y){
 80         for (int i = 0; i < w.size(); ++i){
 81             w[i] += step*y*x[i];
 82 
 83         }
 84         b += step*y;
 85 
 86         for (int i = 0; i < w.size(); ++i)
 87             std::cout << w[i] << ",";
 88         std::cout << std::endl;
 89 
 90         std::cout << b << std::endl;
 91 
 92     }
 93 
 94 private:
 95     int iterators;
 96     std::vector<double> w;
 97     double b;
 98     double step;
 99 };
100 
101 int main(){
102     std::vector<std::vector<double> >test_x(3);
103     test_x[0].push_back(3); test_x[0].push_back(3);
104     test_x[1].push_back(4); test_x[1].push_back(3);
105     test_x[2].push_back(1); test_x[2].push_back(1);
106     std::vector<int> test_y(3);
107     test_y[0] = 1;
108     test_y[1] = 1;
109     test_y[2] = -1;
110 
111     Preception *model = new Preception();
112     model->train(test_x, test_y);
113     model->printPreceptronModel();
114 }

 

统计学习方法笔记2:感知机

标签:

原文地址:http://www.cnblogs.com/xp12/p/4467519.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!