码迷,mamicode.com
首页 > 其他好文 > 详细

《统计学习方法》第八章,提升方法

时间:2019-08-02 00:00:25      阅读:102      评论:0      收藏:0      [点我收藏+]

标签:alpha   size   mat   new   抖动   nat   round   too   real   

● 使用 AdaBoost(Adaptive Boosting)方法提升若干个弱分类器的效果

● 代码,每个感知机仅训练原数据集 trainRatio = 30% 的数据,然后进行调整和提升

  1 import numpy as np
  2 import matplotlib.pyplot as plt
  3 from mpl_toolkits.mplot3d import Axes3D
  4 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
  5 from matplotlib.patches import Rectangle
  6 
dataSize = 500                                                                  # number of samples generated by createData by default
trainDataRatio = 0.3                                                            # fraction of the samples that test() uses as the training set
ita = 0.3                                                                       # step size multiplying each perceptron weight update
defaultTrainRatio = 0.3                                                         # deliberately reduce the number of training rounds of each simple classifier
randomSeed = 103                                                                # seed passed to np.random.seed in createData
 12 
 13 def myColor(x):                                                                 # 颜色函数
 14     r = np.select([x < 1/2, x < 3/4, x <= 1, True],[0, 4 * x - 2, 1, 0])
 15     g = np.select([x < 1/4, x < 3/4, x <= 1, True],[4 * x, 1, 4 - 4 * x, 0])
 16     b = np.select([x < 1/4, x < 1/2, x <= 1, True],[1, 2 - 4 * x, 0, 0])
 17     return [r**2,g**2,b**2]
 18 
 19 def dataSplit(dataX, dataY, part):                                              # 将数据集分割为训练集和测试集
 20     return dataX[:part,:],dataY[:part], dataX[part:,:], dataY[part:]
 21 
 22 def function(x,para):                                                           # 连续回归函数,用于画图
 23     return np.sum(x * para[0]) - para[1]                                        # 注意是减号
 24 
 25 def judgeWeak(x, para):                                                         # 弱分类判别函数
 26     return np.sign(function(x, para))
 27 
 28 def judgeStrong(x, paraList , alpha):                                           # 强分类判别函数,调用弱分类判别函数进行线性加和
 29     return np.sign( np.sum([ judgeWeak(x, paraList[i]) * alpha[i] for i in range(len(paraList)) ]) )
 30 
 31 def targetIndex(x, xList):                                                      # 二分查找 xList 中不大于 x 的最大索引
 32     lp = 0
 33     rp = len(xList) - 1
 34     mp = mp = (lp + rp) >> 1
 35     while lp < mp:
 36         if(xList[mp] > x):
 37             rp = mp
 38         else:
 39             lp = mp
 40         mp = (lp + rp) >> 1
 41     return mp
 42 
 43 def createData(dim, count = dataSize):                                          # 创建数据
 44     np.random.seed(randomSeed)
 45     X = np.random.rand(count, dim)
 46     if dim == 1:
 47         Y = (X > 0.5).astype(int).flatten() * 2 - 1
 48     else:
 49         Y = ((3 - 2 * dim) * X[:,0] + 2 * np.sum(X[:,1:], 1) > 0.5).astype(int) * 2 - 1
 50     print( "dim = %d, dataSize = %d, class1Ratio = %f"%(dim, count, np.sum((Y == 1).astype(int)) / count) )
 51     return X, Y
 52 
 53 def perceptron(dataX, dataY, weight, trainRatio = defaultTrainRatio):           # 单层感知机,只训练 dataX 中占比为 trainRatio 的数据
 54     count, dim = np.shape(dataX)
 55     xE = np.concatenate((dataX, -np.ones(count)[:,np.newaxis]), axis = 1)
 56     w = np.zeros(dim + 1)
 57     accWeight = np.cumsum(weight)                                               # 累加分布列用于随机选取
 58     finishFlag = False
 59     for i in range(int(count * trainRatio)):
 60         j = targetIndex(np.random.rand(), accWeight)                            # 依分布列随机抽取一个样本进行训练
 61         w += ita * (dataY[j] - np.sign(np.sum(xE[j] * w))) * xE[j]
 62     return (w[:-1],w[-1])
 63 
 64 def adaBoost(dataX, dataY, weakCount):                                          # 提升训练
 65     count, dim = np.shape(dataX)
 66     weight = np.ones(count) / count                                             # 样本权重
 67     paraList = []                                                               # 弱分类器的系数
 68     alpha = np.zeros(weakCount)                                                 # 弱分类器的权重
 69     for i in range(weakCount):
 70         para = perceptron(dataX, dataY, weight)                                 # 每次训练后检查训练集的分类情况,调整弱分类器权重和样本权重
 71         trainResult = [ judgeWeak(i, para) for i in dataX ]
 72         trainErrorRatio = np.sum( (np.array(trainResult) != dataY).astype(int) * weight )
 73         paraList.append(para)
 74         alpha[i] = np.log(1 / (trainErrorRatio + 1e-8) - 1) / 2
 75         weight *= np.exp( -alpha[i] * dataY * trainResult )
 76         weight /= np.sum(weight)
 77     return paraList, alpha
 78 
 79 def test(dim, weakCount):                                                       # 测试函数
 80     allX, allY = createData(dim)
 81     trainX, trainY, testX,testY = dataSplit(allX, allY, int(dataSize * trainDataRatio))
 82 
 83     paraList, alpha = adaBoost(trainX, trainY, weakCount)
 84 
 85     testResult = [ judgeStrong(i, paraList, alpha) for i in testX ]
 86     errorRatio = np.sum( (np.array(testResult) != testY).astype(int)**2 ) / (dataSize*(1-trainDataRatio))
 87     print( "dim = %d, weakCount = %d, errorRatio = %f"%(dim, weakCount, round(errorRatio,4)) )
 88     for i in range(weakCount):
 89         print(alpha[i] , "\t\t", paraList[i])
 90 
 91     if dim >= 4:                                                                # 4维以上不画图,只输出测试错误率
 92         return
 93 
 94     classP = [ [],[] ]
 95     errorP = []
 96     for i in range(len(testX)):
 97         if testResult[i] != testY[i]:
 98             if dim == 1:
 99                 errorP.append(np.array([testX[i], int(testY[i]+1)>>1]))
100             else:
101                 errorP.append(np.array(testX[i]))
102         else:
103             classP[int(testResult[i]+1)>>1].append(testX[i])
104     errorP = np.array(errorP)
105     classP = [ np.array(classP[0]), np.array(classP[1]) ]
106 
107     fig = plt.figure(figsize=(10, 8))
108     if dim == 1:
109         plt.xlim(0.0,1.0)
110         plt.ylim(-0.25,1.25)
111         for i in range(2):
112             if(len(classP[i])) > 0:
113                 plt.scatter(classP[i], np.ones(len(classP[i])) * i, color = myColor(i/2), s = 8, label = "class" + str(i))
114         if len(errorP) != 0:
115             plt.scatter(errorP[:,0], errorP[:,1],color = myColor(1), s = 16,label = "errorData")
116 
117         plt.plot([0.5, 0.5], [-0.25, 1.25], color = [0.5,0.25,0],label = "realBoundary")
118         plt.text(0.2, 1.1, "realBoundary: 2x = 1\nerrorRatio = " + str(round(errorRatio,4)),119             size=15, ha="center", va="center", bbox=dict(boxstyle="round", ec=(1., 0.5, 0.5), fc=(1., 1., 1.)))
120         R = [ Rectangle((0,0),0,0, color = myColor(i / 2)) for i in range(2) ] + [ Rectangle((0,0),0,0, color = myColor(1)), Rectangle((0,0),0,0, color = [0.5,0.25,0]) ]
121         plt.legend(R, [ "class" + str(i) for i in range(2) ] + ["errorData", "realBoundary"], loc=[0.81, 0.2], ncol=1, numpoints=1, framealpha = 1)
122 
123     if dim == 2:
124         plt.xlim(-0.1, 1.1)
125         plt.ylim(-0.1, 1.1)
126         for i in range(2):
127             if(len(classP[i])) > 0:
128                 plt.scatter(classP[i][:,0], classP[i][:,1], color = myColor(i/2), s = 8, label = "class" + str(i))
129         if len(errorP) != 0:
130             plt.scatter(errorP[:,0], errorP[:,1], color = myColor(1), s = 16, label = "errorData")
131         plt.plot([0,1], [1/4,3/4], color = [0.5,0.25,0], label = "realBoundary")
132         plt.text(0.78, 1.02, "realBoundary: -x + 2y = 1\nerrorRatio = " + str(round(errorRatio,4)), 133             size = 15, ha="center", va="center", bbox=dict(boxstyle="round", ec=(1., 0.5, 0.5), fc=(1., 1., 1.)))
134         R = [ Rectangle((0,0),0,0, color = myColor(i / 2)) for i in range(2) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
135         plt.legend(R, [ "class" + str(i) for i in range(2) ] + ["errorData"], loc=[0.84, 0.012], ncol=1, numpoints=1, framealpha = 1)
136 
137     if dim == 3:
138         ax = Axes3D(fig)
139         ax.set_xlim3d(0.0, 1.0)
140         ax.set_ylim3d(0.0, 1.0)
141         ax.set_zlim3d(0.0, 1.0)
142         ax.set_xlabel(X, fontdict={size: 15, color: k})
143         ax.set_ylabel(Y, fontdict={size: 15, color: k})
144         ax.set_zlabel(Z, fontdict={size: 15, color: k})
145         for i in range(2):
146             if(len(classP[i])) > 0:
147                 ax.scatter(classP[i][:,0], classP[i][:,1], classP[i][:,2], color = myColor(i/2), s = 8, label = "class" + str(i))
148         if len(errorP) != 0:
149             ax.scatter(errorP[:,0], errorP[:,1],errorP[:,2], color = myColor(1), s = 8, label = "errorData")
150         v = [(0, 0, 0.25), (0, 0.25, 0), (0.5, 1, 0), (1, 1, 0.75), (1, 0.75, 1), (0.5, 0, 1)]
151         f = [[0,1,2,3,4,5]]
152         poly3d = [[v[i] for i in j] for j in f]
153         ax.add_collection3d(Poly3DCollection(poly3d, edgecolor = k, facecolors = [0.5,0.25,0,0.5], linewidths=1))
154         ax.text3D(0.75, 0.92, 1.15, "realBoundary: -3x + 2y +2z = 1\nerrorRatio = " + str(round(errorRatio,4)), 155             size = 12, ha="center", va="center", bbox=dict(boxstyle="round", ec=(1, 0.5, 0.5), fc=(1, 1, 1)))
156         R = [ Rectangle((0,0),0,0, color = myColor(i / 2)) for i in range(2) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
157         plt.legend(R, [ "class" + str(i) for i in range(2) ] + ["errorData"], loc=[0.84, 0.012], ncol=1, numpoints=1, framealpha = 1)
158 
159     fig.savefig("R:\\dim" + str(dim) + "kind2" + "weakCount" + str(weakCount) + ".png")
160     plt.close()
161 
if __name__ == "__main__":
    # Run every combination of data dimension (1-4) and number of weak
    # classifiers (1-4), dimension-major, in the same order as the original call list
    for dim in range(1, 5):
        for weakCount in range(1, 5):
            test(dim, weakCount)

● 输出结果,随着使用的弱分类器数量的增多,预测精度逐渐上升。低维情况不明显,少数的弱分类器就已经达到了较好的精度,高维情况中,精度上升会抖动,被分类的点在分类结果中也会抖动。

dim = 1, dataSize = 500, class1Ratio = 0.492000
dim = 1, weakCount = 1, errorRatio = 0.320000
0.34657356777997284              (array([1.67141915]), 0.29999999999999993)
dim = 1, dataSize = 500, class1Ratio = 0.492000
dim = 1, weakCount = 2, errorRatio = 0.002900
0.34657356777997284              (array([1.67141915]), 0.29999999999999993)
2.6466513960316105               (array([0.59811356]), 0.3)
dim = 1, dataSize = 500, class1Ratio = 0.492000
dim = 1, weakCount = 3, errorRatio = 0.002900
0.34657356777997284              (array([1.67141915]), 0.29999999999999993)
2.6466513960316105               (array([0.59811356]), 0.3)
1.154062035731127                (array([0.70689064]), 0.29999999999999993)
dim = 1, dataSize = 500, class1Ratio = 0.492000
dim = 1, weakCount = 4, errorRatio = 0.002900
0.34657356777997284              (array([1.67141915]), 0.29999999999999993)
2.6466513960316105               (array([0.59811356]), 0.3)
1.154062035731127                (array([0.70689064]), 0.29999999999999993)
0.41049029622924904              (array([0.65816408]), 0.29999999999999993)
dim = 2, dataSize = 500, class1Ratio = 0.520000
dim = 2, weakCount = 1, errorRatio = 0.165700
0.7581737108087062               (array([-0.5342485 ,  0.85301855]), 0.3)
dim = 2, dataSize = 500, class1Ratio = 0.520000
dim = 2, weakCount = 2, errorRatio = 0.140000
0.7581737108087062               (array([-0.5342485 ,  0.85301855]), 0.3)
1.1603017192470149               (array([-0.23046473,  1.17772171]), 0.29999999999999993)
dim = 2, dataSize = 500, class1Ratio = 0.520000
dim = 2, weakCount = 3, errorRatio = 0.082900
0.7581737108087062               (array([-0.5342485 ,  0.85301855]), 0.3)
1.1603017192470149               (array([-0.23046473,  1.17772171]), 0.29999999999999993)
1.366866794214113                (array([-0.86403595,  1.29893022]), 0.3)
dim = 2, dataSize = 500, class1Ratio = 0.520000
dim = 2, weakCount = 4, errorRatio = 0.082900
0.7581737108087062               (array([-0.5342485 ,  0.85301855]), 0.3)
1.1603017192470149               (array([-0.23046473,  1.17772171]), 0.29999999999999993)
1.366866794214113                (array([-0.86403595,  1.29893022]), 0.3)
-0.07595124913479236             (array([-0.71435958,  1.09996259]), 0.3)
dim = 3, dataSize = 500, class1Ratio = 0.544000
dim = 3, weakCount = 1, errorRatio = 0.334300
0.4236489063840784               (array([-1.88583778,  1.00159772,  0.23076269]), 0.3)
dim = 3, dataSize = 500, class1Ratio = 0.544000
dim = 3, weakCount = 2, errorRatio = 0.097100
0.4236489063840784               (array([-1.88583778,  1.00159772,  0.23076269]), 0.3)
1.2147383422658522               (array([-1.11207425,  0.87462922,  1.16116403]), 0.29999999999999993)
dim = 3, dataSize = 500, class1Ratio = 0.544000
dim = 3, weakCount = 3, errorRatio = 0.074300
0.4236489063840784               (array([-1.88583778,  1.00159772,  0.23076269]), 0.3)
1.2147383422658522               (array([-1.11207425,  0.87462922,  1.16116403]), 0.29999999999999993)
1.4030555888409086               (array([-0.90813279,  0.97916935,  0.44726373]), 0.3)
dim = 3, dataSize = 500, class1Ratio = 0.544000
dim = 3, weakCount = 4, errorRatio = 0.088600
0.4236489063840784               (array([-1.88583778,  1.00159772,  0.23076269]), 0.3)
1.2147383422658522               (array([-1.11207425,  0.87462922,  1.16116403]), 0.29999999999999993)
1.4030555888409086               (array([-0.90813279,  0.97916935,  0.44726373]), 0.3)
0.298249916659031                (array([-0.92372522,  1.11109598,  0.9864088 ]), -0.30000000000000004)
dim = 4, dataSize = 500, class1Ratio = 0.484000
dim = 4, weakCount = 1, errorRatio = 0.271400
0.6328331575281093               (array([-1.55413592,  1.59665079,  0.46795061,  1.01271949]), 0.29999999999999993)
dim = 4, dataSize = 500, class1Ratio = 0.484000
dim = 4, weakCount = 2, errorRatio = 0.271400
0.6328331575281093               (array([-1.55413592,  1.59665079,  0.46795061,  1.01271949]), 0.29999999999999993)
0.4566505516305031               (array([-2.06478282, -0.07030723,  0.28072944,  0.50215833]), 0.30000000000000004)
dim = 4, dataSize = 500, class1Ratio = 0.484000
dim = 4, weakCount = 3, errorRatio = 0.271400
0.6328331575281093               (array([-1.55413592,  1.59665079,  0.46795061,  1.01271949]), 0.29999999999999993)
0.4566505516305031               (array([-2.06478282, -0.07030723,  0.28072944,  0.50215833]), 0.30000000000000004)
0.1788001854725199               (array([-1.57793113,  1.11981   ,  0.68428309,  0.48606427]), -0.3)
dim = 4, dataSize = 500, class1Ratio = 0.484000
dim = 4, weakCount = 4, errorRatio = 0.177100
0.6328331575281093               (array([-1.55413592,  1.59665079,  0.46795061,  1.01271949]), 0.29999999999999993)
0.4566505516305031               (array([-2.06478282, -0.07030723,  0.28072944,  0.50215833]), 0.30000000000000004)
0.1788001854725199               (array([-1.57793113,  1.11981   ,  0.68428309,  0.48606427]), -0.3)
0.8838043587493469               (array([-1.46314889,  0.7044062 ,  0.47142833,  0.2926442 ]), 0.3)

● 画图,行:数据维数,列:分别使用 1 ~ 4 个弱分类器

技术图片技术图片技术图片技术图片

技术图片技术图片技术图片技术图片

技术图片技术图片技术图片技术图片

《统计学习方法》第八章,提升方法

标签:alpha   size   mat   new   抖动   nat   round   too   real   

原文地址:https://www.cnblogs.com/cuancuancuanhao/p/11285960.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!