码迷,mamicode.com
首页 > 编程语言 > 详细

AdaBoost的java实现

时间:2015-06-11 12:36:24      阅读:421      评论:0      收藏:0      [点我收藏+]

标签:

目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。

这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。

关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。

bagging和boosting的区别
bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
 
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
  em=技术分享
3 计算该分类器的权重
  技术分享可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
  技术分享
其中Zm是规范化银子:
  技术分享
5 构建基本分类器
  F(X)=技术分享
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
 
终于不用打公式了。。。。
附上代码:
  1 import java.io.BufferedReader;
  2 import java.io.FileInputStream;
  3 import java.io.IOException;
  4 import java.io.InputStreamReader;
  5 import java.util.ArrayList;
  6 
  7 class Stump{
  8     public int dim;
  9     public double thresh;
 10     public String condition;
 11     public double error;
 12     public ArrayList<Integer> labelList;
 13     double factor;
 14     
 15     public String toString(){
 16         return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList;
 17     }
 18 }
 19 
 20 class Utils{
 21     //加载数据集
 22     public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{
 23         ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>();
 24         FileInputStream fis=new FileInputStream(filename);
 25         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
 26         BufferedReader br=new BufferedReader(isr);
 27         String line="";
 28         
 29         while((line=br.readLine())!=null){
 30             ArrayList<Double> data=new ArrayList<Double>();
 31             String[] s=line.split(" ");
 32             
 33             for(int i=0;i<s.length-1;i++){
 34                 data.add(Double.parseDouble(s[i]));
 35             }
 36             dataSet.add(data);
 37         }
 38         return  dataSet;
 39     }
 40     
 41     //加载类别
 42     public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{
 43         ArrayList<Integer> labelSet=new ArrayList<Integer>();
 44         
 45         FileInputStream fis=new FileInputStream(filename);
 46         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
 47         BufferedReader br=new BufferedReader(isr);
 48         String line="";
 49         
 50         while((line=br.readLine())!=null){
 51             String[] s=line.split(" ");
 52             labelSet.add(Integer.parseInt(s[s.length-1]));
 53         }
 54         return labelSet;
 55     }
 56     //测试用的
 57     public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){
 58         for(ArrayList<Double> data:dataSet){
 59             System.out.println(data);
 60         }
 61     }
 62     //获取最大值,用于求步长
 63     public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){
 64         double max=-9999.0;
 65         for(ArrayList<Double> data:dataSet){
 66             if(data.get(index)>max){
 67                 max=data.get(index);
 68             }
 69         }
 70         return max;
 71     }
 72     //获取最小值,用于求步长
 73     public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){
 74         double min=9999.0;
 75         for(ArrayList<Double> data:dataSet){
 76             if(data.get(index)<min){
 77                 min=data.get(index);
 78             }
 79         }
 80         return min;
 81     }
 82     
 83     //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别
 84     public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){
 85         ArrayList<Integer> labelList=new ArrayList<Integer>();
 86         if(condition.compareTo("lt")==0){
 87             for(ArrayList<Double> data:dataSet){
 88                 if(data.get(feature)<=thresh){
 89                     labelList.add(1);
 90                 }else{
 91                     labelList.add(-1);
 92                 }
 93             }
 94         }else{
 95             for(ArrayList<Double> data:dataSet){
 96                 if(data.get(feature)>=thresh){
 97                     labelList.add(1);
 98                 }else{
 99                     labelList.add(-1);
100                 }
101             }
102         }
103         return labelList;
104     }
105     //求预测类别与真实类别的加权误差
106     public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){
107         double error=0;
108         
109         int n=real.size();
110 
111         for(int i=0;i<fake.size();i++){
112             if(fake.get(i)!=real.get(i)){
113                 error+=weights.get(i);
114 
115             }
116         }
117         
118         return error;
119     }
120     //构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。
121     public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){
122         int featureNum=dataSet.get(0).size();
123         
124         int rowNum=dataSet.size();
125         Stump stump=new Stump();
126         double minError=999.0;
127         System.out.println("第"+n+"次迭代");
128         for(int i=0;i<featureNum;i++){
129             double min=getMin(dataSet,i);
130             double max=getMax(dataSet,i);    
131             double step=(max-min)/(rowNum);
132             for(double j=min-step;j<=max+step;j=j+step){
133                 String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类
134                 for(String condition:conditions){
135                     ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition);
136                     
137                     double error=Utils.getError(labelList,labelSet,weights);
138                     if(error<minError){
139                         minError=error;
140                         stump.dim=i;
141                         stump.thresh=j;
142                         stump.condition=condition;
143                         stump.error=minError;
144                         stump.labelList=labelList;
145                         stump.factor=0.5*(Math.log((1-error)/error));
146                     }
147                     
148                 }
149             }
150             
151         }
152         
153         return stump;
154     }
155     
156     public static ArrayList<Double> getInitWeights(int n){
157         double weight=1.0/n;
158         ArrayList<Double> weights=new ArrayList<Double>();
159         for(int i=0;i<n;i++){
160             weights.add(weight);
161         }
162         return weights;
163     }
164     //更新样本权值
165     public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){
166         double Z=0;
167         ArrayList<Double> newWeights=new ArrayList<Double>();
168         int row=labelList.size();
169         double e=Math.E;
170         double factor=stump.factor;
171         for(int i=0;i<row;i++){
172             Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i));
173         }
174         
175         
176         for(int i=0;i<row;i++){
177             double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z;
178             newWeights.add(weight);
179         }
180         return newWeights;
181     }
182     //对加权误差累加
183     public static ArrayList<Double> InitAccWeightError(int n){
184         ArrayList<Double> accError=new ArrayList<Double>();
185         for(int i=0;i<n;i++){
186             accError.add(0.0);
187         }
188         return accError;
189     }
190     
191     public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){
192         ArrayList<Integer> t=stump.labelList;
193         double factor=stump.factor;
194         ArrayList<Double> newAccError=new ArrayList<Double>();
195         for(int i=0;i<t.size();i++){
196             double a=accerror.get(i)+factor*t.get(i);
197             newAccError.add(a);
198         }
199         return newAccError;
200     }
201     
202     public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){
203         ArrayList<Integer> a=new ArrayList<Integer>();
204         int wrong=0;
205         for(int i=0;i<accError.size();i++){
206             if(accError.get(i)>0){
207                 if(labelList.get(i)==-1){
208                     wrong++;
209                 }
210             }else if(labelList.get(i)==1){
211                 wrong++;
212             }
213         }
214         double error=wrong*1.0/accError.size();
215         return error;
216     }
217     
218     public static void showStumpList(ArrayList<Stump> G){
219         for(Stump s:G){
220             System.out.println(s);
221             System.out.println(" ");
222         }
223     }
224 }
225 
226 
227 public class Adaboost {
228 
229     /**
230      * @param args
231      * @throws IOException 
232      */
233     
234     public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){
235         int row=labelList.size();
236         ArrayList<Double> weights=Utils.getInitWeights(row);
237         ArrayList<Stump> G=new ArrayList<Stump>();
238         ArrayList<Double> accError=Utils.InitAccWeightError(row);
239         int n=1;
240         while(true){
241             Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树
242             G.add(stump);
243             weights=Utils.updateWeights(stump,labelList,weights);//更新权值
244             accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了
245             double error=Utils.calErrorRate(accError,labelList);
246             if(error<0.001){
247                 break;
248             }
249             n++;
250         }
251         return G;
252     }
253     
254     public static void main(String[] args) throws IOException {
255         // TODO Auto-generated method stub
256         String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt";
257         ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file);
258         ArrayList<Integer> labelSet=Utils.loadLabelSet(file);
259         ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet);
260         Utils.showStumpList(G);
261         System.out.println("finished");
262     }
263 
264 }

这里的数据采用的是统计学习方法中的数据

0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1

这里是单个特征的,也可以是多维数据,例如

1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1

 

AdaBoost的java实现

标签:

原文地址:http://www.cnblogs.com/sunrye/p/4568647.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!