码迷,mamicode.com
首页 > 其他好文 > 详细

libsvm代码阅读:关于Solver类分析(二)(转)

时间:2015-04-05 15:49:27      阅读:254      评论:0      收藏:0      [点我收藏+]

标签:

如果你看完了上篇博文的伪代码,那么我们就可以开始谈谈它的源代码了。

下面先贴出它的类定义,一些成员函数的具体实现先忽略。

 

[cpp]   view plain copy 技术分享 技术分享
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918  
  2. // Solves:  
  3. //  min 0.5(\alpha^T Q \alpha) + p^T \alpha  
  4. //  
  5. //      y^T \alpha = \delta  
  6. //      y_i = +1 or -1  
  7. //      0 <= alpha_i <= Cp for y_i = 1  
  8. //      0 <= alpha_i <= Cn for y_i = -1  
  9. //  
  10. // Given:  
  11. //  Q, p, y, Cp, Cn, and an initial feasible point \alpha  
  12. //  l is the size of vectors and matrices  
  13. //  eps is the stopping tolerance  
  14. // solution will be put in \alpha, objective value will be put in obj  
  15. //  
  16. class Solver {  
  17. public:  
  18.     Solver() {};  
  19.     virtual ~Solver() {};//用虚析构函数的原因是:保证根据实际运行适当的析构函数  
  20.   
  21.     struct SolutionInfo {  
  22.         double obj;  
  23.         double rho;  
  24.         double upper_bound_p;  
  25.         double upper_bound_n;  
  26.         double r;   // for Solver_NU  
  27.     };  
  28.   
  29.     void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,  
  30.            double *alpha_, double Cp, double Cn, double eps,  
  31.            SolutionInfo* si, int shrinking);  
  32. protected:  
  33.     int active_size;//计算时实际参加运算的样本数目,经过shrink处理后,该数目小于全部样本数  
  34.     schar *y;       //样本所属类别,该值只能取-1或+1。  
  35.     double *G;      // gradient of objective function = (Q alpha + p)  
  36.     enum { LOWER_BOUND, UPPER_BOUND, FREE };  
  37.     char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE   
  38.     double *alpha;      //  
  39.     const QMatrix *Q;     
  40.     const double *QD;  
  41.     double eps;     //误差限  
  42.     double Cp,Cn;  
  43.     double *p;  
  44.     int *active_set;  
  45.     double *G_bar;      // gradient, if we treat free variables as 0  
  46.     int l;  
  47.     bool unshrink;  // XXX  
  48.     //返回对应于样本的C。设置不同的Cp和Cn是为了处理数据的不平衡  
  49.     double get_C(int i)  
  50.     {  
  51.         return (y[i] > 0)? Cp : Cn;  
  52.     }  
  53.   
  54.     void update_alpha_status(int i)  
  55.     {  
  56.         if(alpha[i] >= get_C(i))  
  57.             alpha_status[i] = UPPER_BOUND;  
  58.         else if(alpha[i] <= 0)  
  59.             alpha_status[i] = LOWER_BOUND;  
  60.         else alpha_status[i] = FREE;  
  61.     }  
  62.     bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }  
  63.     bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }  
  64.     bool is_free(int i) { return alpha_status[i] == FREE; }  
  65.     void swap_index(int i, int j);//交换样本i和j的内容,包括申请的内存的地址  
  66.     void reconstruct_gradient();  //重新计算梯度。  
  67.     virtual int select_working_set(int &i, int &j);//选择工作集  
  68.     virtual double calculate_rho();  
  69.     virtual void do_shrinking();//对样本集做缩减。  
  70. private:  
  71.     bool be_shrunk(int i, double Gmax1, double Gmax2);    
  72. };  

 

下面我们来看看SMO如何选择工作集(working set B),选择的约束如下:

 

[cpp]   view plain copy 技术分享 技术分享
<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. // return i,j such that  
  2. // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)  
  3. // j: minimizes the decrease of obj value  
  4. //    (if quadratic coefficeint <= 0, replace it with tau)  
  5. //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)  

论文中的公式如下:

 

技术分享

技术分享

技术分享

[cpp]   view plain copy 技术分享 技术分享
<EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. int Solver::select_working_set(int &out_i, int &out_j)  
  2. {  
  3.     // return i,j such that  
  4.     // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)  
  5.     // j: minimizes the decrease of obj value  
  6.     //    (if quadratic coefficeint <= 0, replace it with tau)  
  7.     //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)  
  8. //select i    
  9.     double Gmax = -INF;  
  10.     double Gmax2 = -INF;  
  11.     int Gmax_idx = -1;  
  12.     int Gmin_idx = -1;  
  13.     double obj_diff_min = INF;  
  14.   
  15.     for(int t=0;t<active_size;t++)  
  16.         if(y[t]==+1)    //若类别为1  
  17.         {  
  18.             if(!is_upper_bound(t))//若alpha<C  
  19.                 if(-G[t] >= Gmax)  
  20.                 {  
  21.                     Gmax = -G[t];// -y[t]*G[t]=-1*G[t]  
  22.                     Gmax_idx = t;  
  23.                 }  
  24.         }  
  25.         else  
  26.         {  
  27.             if(!is_lower_bound(t))  
  28.                 if(G[t] >= Gmax)  
  29.                 {  
  30.                     Gmax = G[t];  
  31.                     Gmax_idx = t;  
  32.                 }  
  33.         }  
  34.   
  35.     int i = Gmax_idx;  
  36.     const Qfloat *Q_i = NULL;  
  37.     if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1  
  38.         Q_i = Q->get_Q(i,active_size);  
  39. //select j  
  40.     for(int j=0;j<active_size;j++)  
  41.     {  
  42.         if(y[j]==+1)  
  43.         {  
  44.             if (!is_lower_bound(j))  
  45.             {  
  46.                 double grad_diff=Gmax+G[j];  
  47.                 if (G[j] >= Gmax2)  
  48.                     Gmax2 = G[j];  
  49.                 if (grad_diff > 0)  
  50.                 {  
  51.                     double obj_diff;   
  52.                     double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];  
  53.                     if (quad_coef > 0)  
  54.                         obj_diff = -(grad_diff*grad_diff)/quad_coef;  
  55.                     else  
  56.                         obj_diff = -(grad_diff*grad_diff)/TAU;  
  57.   
  58.                     if (obj_diff <= obj_diff_min)  
  59.                     {  
  60.                         Gmin_idx=j;  
  61.                         obj_diff_min = obj_diff;  
  62.                     }  
  63.                 }  
  64.             }  
  65.         }  
  66.         else  
  67.         {  
  68.             if (!is_upper_bound(j))  
  69.             {  
  70.                 double grad_diff= Gmax-G[j];  
  71.                 if (-G[j] >= Gmax2)  
  72.                     Gmax2 = -G[j];  
  73.                 if (grad_diff > 0)  
  74.                 {  
  75.                     double obj_diff;   
  76.                     double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];  
  77.                     if (quad_coef > 0)  
  78.                         obj_diff = -(grad_diff*grad_diff)/quad_coef;  
  79.                     else  
  80.                         obj_diff = -(grad_diff*grad_diff)/TAU;  
  81.   
  82.                     if (obj_diff <= obj_diff_min)  
  83.                     {  
  84.                         Gmin_idx=j;  
  85.                         obj_diff_min = obj_diff;  
  86.                     }  
  87.                 }  
  88.             }  
  89.         }  
  90.     }  
  91.   
  92.     if(Gmax+Gmax2 < eps)  
  93.         return 1;  
  94.   
  95.     out_i = Gmax_idx;  
  96.     out_j = Gmin_idx;  
  97.     return 0;  
  98. }  

配合上面几个公式看,这段代码还是很清晰了。

 

下面来看看它的构造函数,这个构造函数是solver类的核心。这个算法也结合上一篇博文的algorithm2来看。其中要注意的是get_Q是获取核函数。

 

[cpp]   view plain copy 技术分享 技术分享
<EMBED id=ZeroClipboardMovie_4 height=18 name=ZeroClipboardMovie_4 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=4&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,  
  2.            double *alpha_, double Cp, double Cn, double eps,  
  3.            SolutionInfo* si, int shrinking)  
  4. {  
  5.     this->l = l;  
  6.     this->Q = &Q;  
  7.     QD=Q.get_QD();//这个是获取核函数(如果分类的话在SVC_Q中定义)  
  8.   
  9.     clone(p, p_,l);  
  10.     clone(y, y_,l);  
  11.     clone(alpha,alpha_,l);  
  12.   
  13.     this->Cp = Cp;  
  14.     this->Cn = Cn;  
  15.     this->eps = eps;  
  16.     unshrink = false;  
  17.   
  18.     // initialize alpha_status  
  19.     {  
  20.         alpha_status = new char[l];  
  21.         for(int i=0;i<l;i++)  
  22.             update_alpha_status(i);  
  23.     }  
  24.   
  25.     // initialize active set (for shrinking)  
  26.     {  
  27.         active_set = new int[l];  
  28.         for(int i=0;i<l;i++)  
  29.             active_set[i] = i;  
  30.         active_size = l;  
  31.     }  
  32.   
  33.     // initialize gradient  
  34.     {  
  35.         G = new double[l];  
  36.         G_bar = new double[l];  
  37.         int i;  
  38.         for(i=0;i<l;i++)  
  39.         {  
  40.             G[i] = p[i];  
  41.             G_bar[i] = 0;  
  42.         }  
  43.         for(i=0;i<l;i++)  
  44.             if(!is_lower_bound(i))  
  45.             {  
  46.                 const Qfloat *Q_i = Q.get_Q(i,l);  
  47.                 double alpha_i = alpha[i];  
  48.                 int j;  
  49.                 for(j=0;j<l;j++)  
  50.                     G[j] += alpha_i*Q_i[j];  
  51.                 if(is_upper_bound(i))  
  52.                     for(j=0;j<l;j++)  
  53.                         G_bar[j] += get_C(i) * Q_i[j]; //这里见文献LIBSVM: A Library for SVM公式(33)  
  54.             }  
  55.     }  
  56.   
  57.     // optimization step  
  58.   
  59.     int iter = 0;  
  60.     int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);  
  61.     int counter = min(l,1000)+1;  
  62.       
  63.     while(iter < max_iter)  
  64.     {  
  65.         // show progress and do shrinking  
  66.   
  67.         if(--counter == 0)  
  68.         {  
  69.             counter = min(l,1000);  
  70.             if(shrinking) do_shrinking();    
  71.             info(".");  
  72.         }  
  73.   
  74.         int i,j;  
  75.         if(select_working_set(i,j)!=0)  
  76.         {  
  77.             // reconstruct the whole gradient  
  78.             reconstruct_gradient();  
  79.             // reset active set size and check  
  80.             active_size = l;  
  81.             info("*");  
  82.             if(select_working_set(i,j)!=0)  
  83.                 break;  
  84.             else  
  85.                 counter = 1;    // do shrinking next iteration  
  86.         }  
  87.           
  88.         ++iter;  
  89.   
  90.         // update alpha[i] and alpha[j], handle bounds carefully  
  91.           
  92.         const Qfloat *Q_i = Q.get_Q(i,active_size);  
  93.         const Qfloat *Q_j = Q.get_Q(j,active_size);  
  94.   
  95.         double C_i = get_C(i);  
  96.         double C_j = get_C(j);  
  97.   
  98.         double old_alpha_i = alpha[i];  
  99.         double old_alpha_j = alpha[j];  
  100.   
  101.         if(y[i]!=y[j])  
  102.         {  
  103.             double quad_coef = QD[i]+QD[j]+2*Q_i[j];  
  104.             if (quad_coef <= 0)  
  105.                 quad_coef = TAU;  
  106.             double delta = (-G[i]-G[j])/quad_coef;  
  107.             double diff = alpha[i] - alpha[j];  
  108.             alpha[i] += delta;  
  109.             alpha[j] += delta;  
  110.               
  111.             if(diff > 0)  
  112.             {  
  113.                 if(alpha[j] < 0)  
  114.                 {  
  115.                     alpha[j] = 0;  
  116.                     alpha[i] = diff;  
  117.                 }  
  118.             }  
  119.             else  
  120.             {  
  121.                 if(alpha[i] < 0)  
  122.                 {  
  123.                     alpha[i] = 0;  
  124.                     alpha[j] = -diff;  
  125.                 }  
  126.             }  
  127.             if(diff > C_i - C_j)  
  128.             {  
  129.                 if(alpha[i] > C_i)  
  130.                 {  
  131.                     alpha[i] = C_i;  
  132.                     alpha[j] = C_i - diff;  
  133.                 }  
  134.             }  
  135.             else  
  136.             {  
  137.                 if(alpha[j] > C_j)  
  138.                 {  
  139.                     alpha[j] = C_j;  
  140.                     alpha[i] = C_j + diff;  
  141.                 }  
  142.             }  
  143.         }  
  144.         else  
  145.         {  
  146.             double quad_coef = QD[i]+QD[j]-2*Q_i[j];  
  147.             if (quad_coef <= 0)  
  148.                 quad_coef = TAU;  
  149.             double delta = (G[i]-G[j])/quad_coef;  
  150.             double sum = alpha[i] + alpha[j];  
  151.             alpha[i] -= delta;  
  152.             alpha[j] += delta;  
  153.   
  154.             if(sum > C_i)  
  155.             {  
  156.                 if(alpha[i] > C_i)  
  157.                 {  
  158.                     alpha[i] = C_i;  
  159.                     alpha[j] = sum - C_i;  
  160.                 }  
  161.             }  
  162.             else  
  163.             {  
  164.                 if(alpha[j] < 0)  
  165.                 {  
  166.                     alpha[j] = 0;  
  167.                     alpha[i] = sum;  
  168.                 }  
  169.             }  
  170.             if(sum > C_j)  
  171.             {  
  172.                 if(alpha[j] > C_j)  
  173.                 {  
  174.                     alpha[j] = C_j;  
  175.                     alpha[i] = sum - C_j;  
  176.                 }  
  177.             }  
  178.             else  
  179.             {  
  180.                 if(alpha[i] < 0)  
  181.                 {  
  182.                     alpha[i] = 0;  
  183.                     alpha[j] = sum;  
  184.                 }  
  185.             }  
  186.         }  
  187.   
  188.         // update G  
  189.   
  190.         double delta_alpha_i = alpha[i] - old_alpha_i;  
  191.         double delta_alpha_j = alpha[j] - old_alpha_j;  
  192.           
  193.         for(int k=0;k<active_size;k++)  
  194.         {  
  195.             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;  
  196.         }  
  197.   
  198.         // update alpha_status and G_bar  
  199.   
  200.         {  
  201.             bool ui = is_upper_bound(i);  
  202.             bool uj = is_upper_bound(j);  
  203.             update_alpha_status(i);  
  204.             update_alpha_status(j);  
  205.             int k;  
  206.             if(ui != is_upper_bound(i))  
  207.             {  
  208.                 Q_i = Q.get_Q(i,l);  
  209.                 if(ui)  
  210.                     for(k=0;k<l;k++)  
  211.                         G_bar[k] -= C_i * Q_i[k];  
  212.                 else  
  213.                     for(k=0;k<l;k++)  
  214.                         G_bar[k] += C_i * Q_i[k];  
  215.             }  
  216.   
  217.             if(uj != is_upper_bound(j))  
  218.             {  
  219.                 Q_j = Q.get_Q(j,l);  
  220.                 if(uj)  
  221.                     for(k=0;k<l;k++)  
  222.                         G_bar[k] -= C_j * Q_j[k];  
  223.                 else  
  224.                     for(k=0;k<l;k++)  
  225.                         G_bar[k] += C_j * Q_j[k];  
  226.             }  
  227.         }  
  228.     }  
  229.   
  230.     if(iter >= max_iter)  
  231.     {  
  232.         if(active_size < l)  
  233.         {  
  234.             // reconstruct the whole gradient to calculate objective value  
  235.             reconstruct_gradient();  
  236.             active_size = l;  
  237.             info("*");  
  238.         }  
  239.         fprintf(stderr,"\nWARNING: reaching max number of iterations\n");  
  240.     }  
  241.   
  242.     // calculate rho  
  243.   
  244.     si->rho = calculate_rho();  
  245.   
  246.     // calculate objective value  
  247.     {  
  248.         double v = 0;  
  249.         int i;  
  250.         for(i=0;i<l;i++)  
  251.             v += alpha[i] * (G[i] + p[i]);  
  252.   
  253.         si->obj = v/2;  
  254.     }  
  255.   
  256.     // put back the solution  
  257.     {  
  258.         for(int i=0;i<l;i++)  
  259.             alpha_[active_set[i]] = alpha[i];  
  260.     }  
  261.   
  262.     // juggle everything back  
  263.     /*{ 
  264.         for(int i=0;i<l;i++) 
  265.             while(active_set[i] != i) 
  266.                 swap_index(i,active_set[i]); 
  267.                 // or Q.swap_index(i,active_set[i]); 
  268.     }*/  
  269.   
  270.     si->upper_bound_p = Cp;  
  271.     si->upper_bound_n = Cn;  
  272.   
  273.     info("\noptimization finished, #iter = %d\n",iter);  
  274.   
  275.     delete[] p;  
  276.     delete[] y;  
  277.     delete[] alpha;  
  278.     delete[] alpha_status;  
  279.     delete[] active_set;  
  280.     delete[] G;  
  281.     delete[] G_bar;  
  282. }  

libsvm代码阅读:关于Solver类分析(二)(转)

标签:

原文地址:http://www.cnblogs.com/Miliery/p/4394149.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!