
The AdaBoost Algorithm

Posted: 2016-03-17 21:27:26


For the underlying theory, see Chapter 8 of 《统计学习方法》 (Statistical Learning Methods, by Li Hang).
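
For reference while reading the code, each boosting round computes a weak-classifier weight from its weighted error and then re-weights the samples (the standard AdaBoost update described in that chapter):

$$\alpha_m = \frac{1}{2}\ln\frac{1-e_m}{e_m}, \qquad w_{m+1,i} = \frac{w_{m,i}\,\exp\!\bigl(-\alpha_m\, y_i\, G_m(x_i)\bigr)}{Z_m}$$

where $e_m$ is the weighted error of the $m$-th decision stump $G_m$ and $Z_m$ normalizes the weights to sum to 1; the final classifier is $\mathrm{sign}\bigl(\sum_m \alpha_m G_m(x)\bigr)$.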

A simple implementation:

from numpy import *
import matplotlib.pyplot as plt

def loadSimpData():
    # A tiny 5-sample, 2-feature data set for demonstrating AdaBoost.
    dataMat = matrix([[1.0, 2.1],
                      [2.0, 1.1],
                      [1.3, 1.0],
                      [1.0, 1.0],
                      [2.0, 1.0]])
    classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
    return dataMat, classLabels

def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    # Classify every sample by comparing feature `dimen` against threshVal;
    # threshIneq ('lt' or 'gt') chooses which side of the threshold gets -1.
    retArray = ones((shape(dataMatrix)[0], 1))
    if threshIneq == 'lt':
        retArray[dataMatrix[:, dimen] <= threshVal] = -1.0
    else:
        retArray[dataMatrix[:, dimen] > threshVal] = -1.0
    return retArray

def buildStump(dataArr, classLabels, D):
    # Find the decision stump (feature, threshold, direction) with the
    # smallest error weighted by the sample-weight vector D.
    dataMatrix = mat(dataArr); labelMat = mat(classLabels).T
    m, n = shape(dataMatrix)
    numSteps = 10.0; bestStump = {}; bestClasEst = mat(zeros((m, 1)))
    minError = inf
    for i in range(n):                                  # every feature
        rangeMin = dataMatrix[:, i].min(); rangeMax = dataMatrix[:, i].max()
        stepSize = (rangeMax - rangeMin) / numSteps
        for j in range(-1, int(numSteps) + 1):          # sweep thresholds across the feature's range
            for inequal in ['lt', 'gt']:                # try both inequality directions
                threshVal = rangeMin + float(j) * stepSize
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                errArr = mat(ones((m, 1)))
                errArr[predictedVals == labelMat] = 0
                weightedError = D.T * errArr            # weighted error e_m
                #print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError))
                if weightedError < minError:
                    minError = weightedError
                    bestClasEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst

def adaBoostTrainDS(dataArr, classLabels, numIt=40):
    # AdaBoost training with decision stumps (DS) as the weak learner.
    weakClassArr = []
    m = shape(dataArr)[0]
    D = mat(ones((m, 1)) / m)                # start from uniform sample weights
    aggClassEst = mat(zeros((m, 1)))
    for i in range(numIt):
        bestStump, error, classEst = buildStump(dataArr, classLabels, D)
        print("D:", D.T)
        # alpha = 0.5 * ln((1 - e) / e); max() guards against division by zero
        alpha = float(0.5 * log((1.0 - error) / max(error, 1e-16)))
        bestStump['alpha'] = alpha
        weakClassArr.append(bestStump)
        print("classEst:", classEst)
        # Increase the weights of misclassified samples, decrease the rest, then normalize
        expon = multiply(-1 * alpha * mat(classLabels).T, classEst)
        D = multiply(D, exp(expon))
        D = D / D.sum()
        # Running alpha-weighted vote of all stumps trained so far
        aggClassEst += alpha * classEst
        print("aggClassEst:", aggClassEst.T)
        aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T, ones((m, 1)))
        errorRate = aggErrors.sum() / m
        print("total error:", errorRate, "\n")
        if errorRate == 0.0: break           # stop early once training error reaches zero
    return weakClassArr

dataMat, classLabels = loadSimpData()
classifierArray = adaBoostTrainDS(dataMat, classLabels, 9)
print(classifierArray)
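
The listing above stops at training. As a small usage sketch (adaClassify is not part of the original post; it simply reuses stumpClassify and the classifierArray produced above), the trained stumps can be applied to new points by summing their alpha-weighted votes and taking the sign:

def adaClassify(datToClass, classifierArr):
    # Apply every trained stump to the input points and accumulate the
    # alpha-weighted votes; the sign of the sum is the predicted label.
    dataMatrix = mat(datToClass)
    m = shape(dataMatrix)[0]
    aggClassEst = mat(zeros((m, 1)))
    for stump in classifierArr:
        classEst = stumpClassify(dataMatrix, stump['dim'],
                                 stump['thresh'], stump['ineq'])
        aggClassEst += stump['alpha'] * classEst
    return sign(aggClassEst)

# Example: classify two new points with the ensemble trained above.
print(adaClassify([[5, 5], [0, 0]], classifierArray))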

Original post: http://www.cnblogs.com/JustForCS/p/5289146.html
