标签:
目前学了几个经典的机器学习(ML)分类算法,但一直在想是否有一种方法能把这些算法集成起来。今天看到了AdaBoost,它也算是半个集成方法(把同一种弱分类器组合起来)。感觉这个思路挺好,很像人的训练过程,而且对决策树是一个很好的补充:决策树容易过拟合,用AdaBoost可以把一棵很深的决策树换成多棵矮树的加权组合。后来发现这个想法和随机森林(random forest,RF)比较相似,RF的代码等下周有空的时候再写一下。
这个算法貌似挺厉害的:听专门搞学术的人说,有一篇很有分量的论文证明了可以把弱学习算法提升为强学习算法。我这种搞工程的,能知道它的原理和适用范围,并且能自己写一遍代码,感觉还是比读几遍论文却只能夸夸其谈要安心些。
关于AdaBoost的基本概念,下面参照《统计学习方法》概要地说一下。
1 import java.io.BufferedReader; 2 import java.io.FileInputStream; 3 import java.io.IOException; 4 import java.io.InputStreamReader; 5 import java.util.ArrayList; 6 7 class Stump{ 8 public int dim; 9 public double thresh; 10 public String condition; 11 public double error; 12 public ArrayList<Integer> labelList; 13 double factor; 14 15 public String toString(){ 16 return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList; 17 } 18 } 19 20 class Utils{ 21 //加载数据集 22 public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{ 23 ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>(); 24 FileInputStream fis=new FileInputStream(filename); 25 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 26 BufferedReader br=new BufferedReader(isr); 27 String line=""; 28 29 while((line=br.readLine())!=null){ 30 ArrayList<Double> data=new ArrayList<Double>(); 31 String[] s=line.split(" "); 32 33 for(int i=0;i<s.length-1;i++){ 34 data.add(Double.parseDouble(s[i])); 35 } 36 dataSet.add(data); 37 } 38 return dataSet; 39 } 40 41 //加载类别 42 public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{ 43 ArrayList<Integer> labelSet=new ArrayList<Integer>(); 44 45 FileInputStream fis=new FileInputStream(filename); 46 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 47 BufferedReader br=new BufferedReader(isr); 48 String line=""; 49 50 while((line=br.readLine())!=null){ 51 String[] s=line.split(" "); 52 labelSet.add(Integer.parseInt(s[s.length-1])); 53 } 54 return labelSet; 55 } 56 //测试用的 57 public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){ 58 for(ArrayList<Double> data:dataSet){ 59 System.out.println(data); 60 } 61 } 62 //获取最大值,用于求步长 63 public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){ 64 double max=-9999.0; 65 for(ArrayList<Double> data:dataSet){ 66 
if(data.get(index)>max){ 67 max=data.get(index); 68 } 69 } 70 return max; 71 } 72 //获取最小值,用于求步长 73 public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){ 74 double min=9999.0; 75 for(ArrayList<Double> data:dataSet){ 76 if(data.get(index)<min){ 77 min=data.get(index); 78 } 79 } 80 return min; 81 } 82 83 //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别 84 public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){ 85 ArrayList<Integer> labelList=new ArrayList<Integer>(); 86 if(condition.compareTo("lt")==0){ 87 for(ArrayList<Double> data:dataSet){ 88 if(data.get(feature)<=thresh){ 89 labelList.add(1); 90 }else{ 91 labelList.add(-1); 92 } 93 } 94 }else{ 95 for(ArrayList<Double> data:dataSet){ 96 if(data.get(feature)>=thresh){ 97 labelList.add(1); 98 }else{ 99 labelList.add(-1); 100 } 101 } 102 } 103 return labelList; 104 } 105 //求预测类别与真实类别的加权误差 106 public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){ 107 double error=0; 108 109 int n=real.size(); 110 111 for(int i=0;i<fake.size();i++){ 112 if(fake.get(i)!=real.get(i)){ 113 error+=weights.get(i); 114 115 } 116 } 117 118 return error; 119 } 120 //构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。 121 public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){ 122 int featureNum=dataSet.get(0).size(); 123 124 int rowNum=dataSet.size(); 125 Stump stump=new Stump(); 126 double minError=999.0; 127 System.out.println("第"+n+"次迭代"); 128 for(int i=0;i<featureNum;i++){ 129 double min=getMin(dataSet,i); 130 double max=getMax(dataSet,i); 131 double step=(max-min)/(rowNum); 132 for(double j=min-step;j<=max+step;j=j+step){ 133 String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类 134 for(String condition:conditions){ 135 ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); 136 137 double 
error=Utils.getError(labelList,labelSet,weights); 138 if(error<minError){ 139 minError=error; 140 stump.dim=i; 141 stump.thresh=j; 142 stump.condition=condition; 143 stump.error=minError; 144 stump.labelList=labelList; 145 stump.factor=0.5*(Math.log((1-error)/error)); 146 } 147 148 } 149 } 150 151 } 152 153 return stump; 154 } 155 156 public static ArrayList<Double> getInitWeights(int n){ 157 double weight=1.0/n; 158 ArrayList<Double> weights=new ArrayList<Double>(); 159 for(int i=0;i<n;i++){ 160 weights.add(weight); 161 } 162 return weights; 163 } 164 //更新样本权值 165 public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){ 166 double Z=0; 167 ArrayList<Double> newWeights=new ArrayList<Double>(); 168 int row=labelList.size(); 169 double e=Math.E; 170 double factor=stump.factor; 171 for(int i=0;i<row;i++){ 172 Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i)); 173 } 174 175 176 for(int i=0;i<row;i++){ 177 double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z; 178 newWeights.add(weight); 179 } 180 return newWeights; 181 } 182 //对加权误差累加 183 public static ArrayList<Double> InitAccWeightError(int n){ 184 ArrayList<Double> accError=new ArrayList<Double>(); 185 for(int i=0;i<n;i++){ 186 accError.add(0.0); 187 } 188 return accError; 189 } 190 191 public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){ 192 ArrayList<Integer> t=stump.labelList; 193 double factor=stump.factor; 194 ArrayList<Double> newAccError=new ArrayList<Double>(); 195 for(int i=0;i<t.size();i++){ 196 double a=accerror.get(i)+factor*t.get(i); 197 newAccError.add(a); 198 } 199 return newAccError; 200 } 201 202 public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){ 203 ArrayList<Integer> a=new ArrayList<Integer>(); 204 int wrong=0; 205 for(int i=0;i<accError.size();i++){ 206 if(accError.get(i)>0){ 207 
if(labelList.get(i)==-1){ 208 wrong++; 209 } 210 }else if(labelList.get(i)==1){ 211 wrong++; 212 } 213 } 214 double error=wrong*1.0/accError.size(); 215 return error; 216 } 217 218 public static void showStumpList(ArrayList<Stump> G){ 219 for(Stump s:G){ 220 System.out.println(s); 221 System.out.println(" "); 222 } 223 } 224 } 225 226 227 public class Adaboost { 228 229 /** 230 * @param args 231 * @throws IOException 232 */ 233 234 public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){ 235 int row=labelList.size(); 236 ArrayList<Double> weights=Utils.getInitWeights(row); 237 ArrayList<Stump> G=new ArrayList<Stump>(); 238 ArrayList<Double> accError=Utils.InitAccWeightError(row); 239 int n=1; 240 while(true){ 241 Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树 242 G.add(stump); 243 weights=Utils.updateWeights(stump,labelList,weights);//更新权值 244 accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了 245 double error=Utils.calErrorRate(accError,labelList); 246 if(error<0.001){ 247 break; 248 } 249 n++; 250 } 251 return G; 252 } 253 254 public static void main(String[] args) throws IOException { 255 // TODO Auto-generated method stub 256 String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt"; 257 ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file); 258 ArrayList<Integer> labelSet=Utils.loadLabelSet(file); 259 ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet); 260 Utils.showStumpList(G); 261 System.out.println("finished"); 262 } 263 264 }
这里的数据采用的是《统计学习方法》中AdaBoost例题的数据。文件每行一个样本,各列以空格分隔,最后一列是类别(1或-1),其余列是特征:
0 1
1 1
2 1
3 -1
4 -1
5 -1
6 1
7 1
8 1
9 -1
上面的数据只有一个特征;也可以使用多维特征的数据,例如:
1.0 2.1 1
2.0 1.1 1
1.3 1.0 -1
1.0 1.0 -1
2.0 1.0 1
标签:
原文地址:http://www.cnblogs.com/sunrye/p/4568647.html