码迷,mamicode.com
首页 > 编程语言 > 详细

机器学习实战笔记(Python实现)-02-决策树

时间:2016-12-15 17:44:31      阅读:780      评论:0      收藏:0      [点我收藏+]

标签:btree   try   tor   ever   math   完整   data   support   app   

属原创文章,欢迎转载,但请注明出处:http://www.cnblogs.com/hemiy/p/6165759.html 谢谢!

代码及数据-->https://github.com/Wellat/MLaction

1、算法概述及实现

1.1 算法特点

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据

缺点:可能会产生过度匹配问题

适用数据类型:数值型和标称型

1.2 构造决策树

在构造决策树时,需要解决的第一个问题就是,评估当前数据集上哪个特征在划分数据分类时起决定性作用。本书使用ID3算法划分数据集,即通过对比选择不同特征下数据集的信息增益和香农熵来确定最优划分特征。香农熵的定义如下:

技术分享

1.2.1 计算香农熵:

 1 from math import log
 2 import operator
 3 
 4 def createDataSet():
 5     ‘‘‘
 6     产生测试数据
 7     ‘‘‘
 8     dataSet = [[1, 1, yes],
 9                [1, 1, yes],
10                [1, 0, no],
11                [0, 1, no],
12                [0, 1, no]]
13     labels = [no surfacing,flippers]    
14     return dataSet, labels
15 
16 def calcShannonEnt(dataSet):
17     ‘‘‘
18     计算给定数据集的香农熵
19     ‘‘‘
20     numEntries = len(dataSet)
21     labelCounts = {}
22     #统计每个类别出现的次数,保存在字典labelCounts中
23     for featVec in dataSet: 
24         currentLabel = featVec[-1]
25         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
26         labelCounts[currentLabel] += 1 #如果当前键值不存在,则扩展字典并将当前键值加入字典
27     shannonEnt = 0.0
28     for key in labelCounts:
29         #使用所有类标签的发生频率计算类别出现的概率
30         prob = float(labelCounts[key])/numEntries
31         #用这个概率计算香农熵
32         shannonEnt -= prob * log(prob,2) #取2为底的对数
33     return shannonEnt
34 
35 if __name__== "__main__":  
36     ‘‘‘
37     计算给定数据集的香农熵
38     ‘‘‘
39     dataSet,labels = createDataSet()
40     shannonEnt = calcShannonEnt(dataSet)

 

1.2.2 划分数据集

 1 def splitDataSet(dataSet, axis, value):
 2     ‘‘‘
 3     按照给定特征划分数据集
 4     dataSet:待划分的数据集
 5     axis:   划分数据集的第axis个特征
 6     value:  特征的返回值(比较值)
 7     ‘‘‘
 8     retDataSet = []
 9     #遍历数据集中的每个元素,一旦发现符合要求的值,则将其添加到新创建的列表中
10     for featVec in dataSet:
11         if featVec[axis] == value:
12             reducedFeatVec = featVec[:axis]
13             reducedFeatVec.extend(featVec[axis+1:])
14             retDataSet.append(reducedFeatVec)
15             #extend()和append()方法功能相似,但在处理列表时,处理结果完全不同
16             #a=[1,2,3]  b=[4,5,6]
17             #a.append(b) = [1,2,3,[4,5,6]]
18             #a.extend(b) = [1,2,3,4,5,6]
19     return retDataSet

划分数据集的结果如下所示:

技术分享

 

选择最好的数据集划分方式。接下来我们将遍历整个数据集,循环计算香农熵和 splitDataSet() 函数,找到最好的特征划分方式。  

 1 def chooseBestFeatureToSplit(dataSet):
 2     ‘‘‘
 3     选择最好的数据集划分方式
 4     输入:数据集
 5     输出:最优分类的特征的index
 6     ‘‘‘
 7     #计算特征数量
 8     numFeatures = len(dataSet[0]) - 1
 9     baseEntropy = calcShannonEnt(dataSet)
10     bestInfoGain = 0.0; bestFeature = -1
11     for i in range(numFeatures):
12         #创建唯一的分类标签列表
13         featList = [example[i] for example in dataSet]
14         uniqueVals = set(featList)
15         #计算每种划分方式的信息熵
16         newEntropy = 0.0
17         for value in uniqueVals:
18             subDataSet = splitDataSet(dataSet, i, value)
19             prob = len(subDataSet)/float(len(dataSet))
20             newEntropy += prob * calcShannonEnt(subDataSet)     
21         infoGain = baseEntropy - newEntropy
22         #计算最好的信息增益,即infoGain越大划分效果越好
23         if (infoGain > bestInfoGain):
24             bestInfoGain = infoGain
25             bestFeature = i
26     return bestFeature

 

1.2.3 递归构建决策树

目前我们已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。 

由于特征数目并不是在每次划分数据分组时都减少,因此这些算法在实际使用时可能引起一定的问题。目前我们并不需要考虑这个问题,只需要在算法开始运行前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们通常会采用多数表决的方法决定该叶子节点的分类。

在 trees.py 中增加如下投票表决代码:

 1 import operator
 2 def majorityCnt(classList):
 3     ‘‘‘
 4     投票表决函数
 5     输入classList:标签集合,本例为:[‘yes‘, ‘yes‘, ‘no‘, ‘no‘, ‘no‘]
 6     输出:得票数最多的分类名称
 7     ‘‘‘
 8     classCount={}
 9     for vote in classList:
10         if vote not in classCount.keys(): classCount[vote] = 0
11         classCount[vote] += 1
12     #把分类结果进行排序,然后返回得票数最多的分类结果
13     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
14     return sortedClassCount[0][0]

创建树的函数代码(主函数):

 1 def createTree(dataSet,labels):
 2     ‘‘‘
 3     创建树
 4     输入:数据集和标签列表
 5     输出:树的所有信息
 6     ‘‘‘
 7     # classList为数据集的所有类标签
 8     classList = [example[-1] for example in dataSet]
 9     # 停止条件1:所有类标签完全相同,直接返回该类标签
10     if classList.count(classList[0]) == len(classList): 
11         return classList[0]
12     # 停止条件2:遍历完所有特征时仍不能将数据集划分成仅包含唯一类别的分组,则返回出现次数最多的类标签
13     #
14     if len(dataSet[0]) == 1:
15         return majorityCnt(classList)
16     # 选择最优分类特征
17     bestFeat = chooseBestFeatureToSplit(dataSet)
18     bestFeatLabel = labels[bestFeat]
19     # myTree存储树的所有信息
20     myTree = {bestFeatLabel:{}}
21     # 以下得到列表包含的所有属性值
22     del(labels[bestFeat])
23     featValues = [example[bestFeat] for example in dataSet]
24     uniqueVals = set(featValues)
25     # 遍历当前选择特征包含的所有属性值
26     for value in uniqueVals:
27         subLabels = labels[:]
28         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
29     return myTree

 本例返回 myTree 为字典类型,如下:

{‘no surfacing‘: {0: ‘no‘, 1: {‘flippers‘: {0: ‘no‘, 1: ‘yes‘}}}}

2、测试分类和存储分类器

利用决策树的分类函数:

 1 def classify(inputTree,featLabels,testVec):
 2     ‘‘‘
 3     决策树的分类函数
 4     inputTree:训练好的树信息
 5     featLabels:标签列表
 6     testVec:测试向量
 7     ‘‘‘
 8     # 在2.7中,找到key所对应的第一个元素为:firstStr = myTree.keys()[0],
 9     # 这在3.4中运行会报错:‘dict_keys‘ object does not support indexing,这是因为python3改变了dict.keys,
10     # 返回的是dict_keys对象,支持iterable 但不支持indexable,
11     # 我们可以将其明确的转化成list,则此项功能在3中应这样实现:
12     firstSides = list(inputTree.keys())
13     firstStr = firstSides[0]
14     secondDict = inputTree[firstStr]
15     # 将标签字符串转换成索引
16     featIndex = featLabels.index(firstStr)
17     key = testVec[featIndex]
18     valueOfFeat = secondDict[key]
19     # 递归遍历整棵树,比较testVec变量中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签
20     if isinstance(valueOfFeat, dict): 
21         classLabel = classify(valueOfFeat, featLabels, testVec)
22     else: classLabel = valueOfFeat
23     return classLabel
24 
25 if __name__== "__main__":  
26     ‘‘‘
27     测试分类效果
28     ‘‘‘
29     dataSet,labels = createDataSet()
30     myTree = createTree(dataSet,labels)
31     ans = classify(myTree,labels,[1,0])

决策树模型的存储

 1 def storeTree(inputTree,filename):
 2     ‘‘‘
 3     使用pickle模块存储决策树
 4     ‘‘‘
 5     import pickle
 6     fw = open(filename,wb+)
 7     pickle.dump(inputTree,fw)
 8     fw.close()
 9     
10 def grabTree(filename):
11     ‘‘‘
12     导入决策树模型
13     ‘‘‘
14     import pickle
15     fr = open(filename,rb)
16     return pickle.load(fr)
17 
18 if __name__== "__main__":
19     ‘‘‘
20     存取操作
21     ‘‘‘
22     storeTree(myTree,mt.txt)
23     myTree2 = grabTree(mt.txt)

  

3、使用 Matplotlib 绘制树形图

上节我们已经学习如何从数据集中创建决策树,然而字典的表示形式非常不易于理解,决策树的主要优点就是直观易于理解,如果不能将其直观显示出来,就无法发挥其优势。本节使用 Matplotlib 库编写代码绘制决策树。

创建名为 treePlotter.py 的新文件:

3.1 绘制树节点

 1 import matplotlib.pyplot as plt
 2 
 3 # 定义文本框和箭头格式
 4 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
 5 leafNode = dict(boxstyle="round4", fc="0.8")
 6 arrow_args = dict(arrowstyle="<-")
 7 
 8 # 绘制带箭头的注释
 9 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
10     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords=axes fraction,
11              xytext=centerPt, textcoords=axes fraction,
12              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
13 
14 def createPlot():
15     fig = plt.figure(1, facecolor=grey)
16     fig.clf()
17     # 定义绘图区
18     createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
19     plotNode(a decision node, (0.5, 0.1), (0.1, 0.5), decisionNode)
20     plotNode(a leaf node, (0.8, 0.1), (0.3, 0.8), leafNode)
21     plt.show()
22 
23 if __name__== "__main__":  
24     ‘‘‘
25     绘制树节点
26     ‘‘‘
27     createPlot()

结果如下:??

技术分享

3.2 构造注解树

绘制一棵完整的树需要一些技巧。我们虽然有 x, y 坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确x轴的长度;我们还需要知道树有多少层,来确定y轴的高度。这里另一两个新函数 getNumLeafs() 和 getTreeDepth() ,来获取叶节点的数目和树的层数,createPlot() 为主函数,完整代码如下:

  1 import matplotlib.pyplot as plt
  2 
  3 # 定义文本框和箭头格式
  4 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
  5 leafNode = dict(boxstyle="round4", fc="0.8")
  6 arrow_args = dict(arrowstyle="<-")
  7 
  8 # 绘制带箭头的注释
  9 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
 10     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords=axes fraction,
 11              xytext=centerPt, textcoords=axes fraction,
 12              va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
 13 
 14 def createPlot(inTree):
 15     ‘‘‘
 16     绘树主函数
 17     ‘‘‘
 18     fig = plt.figure(1, facecolor=white)
 19     fig.clf()
 20     # 设置坐标轴数据
 21     axprops = dict(xticks=[], yticks=[])
 22     # 无坐标轴
 23     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
 24     # 带坐标轴
 25 #    createPlot.ax1 = plt.subplot(111, frameon=False)
 26     plotTree.totalW = float(getNumLeafs(inTree))
 27     plotTree.totalD = float(getTreeDepth(inTree))
 28     # 两个全局变量plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,
 29     # 以及放置下一个节点的恰当位置
 30     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
 31     plotTree(inTree, (0.5,1.0), ‘‘)
 32     plt.show()
 33 
 34 
 35 def getNumLeafs(myTree):
 36     ‘‘‘
 37     获取叶节点的数目
 38     ‘‘‘
 39     numLeafs = 0
 40     firstSides = list(myTree.keys())
 41     firstStr = firstSides[0]
 42     secondDict = myTree[firstStr]
 43     for key in secondDict.keys():
 44         # 判断节点是否为字典来以此判断是否为叶子节点
 45         if type(secondDict[key]).__name__==dict:
 46             numLeafs += getNumLeafs(secondDict[key])
 47         else:   numLeafs +=1
 48     return numLeafs
 49 
 50 def getTreeDepth(myTree):
 51     ‘‘‘
 52     获取树的层数
 53     ‘‘‘
 54     maxDepth = 0
 55     firstSides = list(myTree.keys())
 56     firstStr = firstSides[0]
 57     secondDict = myTree[firstStr]
 58     for key in secondDict.keys():
 59         if type(secondDict[key]).__name__==dict:
 60             thisDepth = 1 + getTreeDepth(secondDict[key])
 61         else:   thisDepth = 1
 62         if thisDepth > maxDepth: maxDepth = thisDepth
 63     return maxDepth
 64 
 65 
 66 def plotMidText(cntrPt, parentPt, txtString):
 67     ‘‘‘
 68     计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息
 69     ‘‘‘
 70     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
 71     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
 72     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
 73 
 74 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
 75     # 计算宽与高
 76     numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
 77     depth = getTreeDepth(myTree)
 78     firstSides = list(myTree.keys())
 79     firstStr = firstSides[0]
 80     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
 81     # 标记子节点属性值    
 82     plotMidText(cntrPt, parentPt, nodeTxt)
 83     plotNode(firstStr, cntrPt, parentPt, decisionNode)
 84     secondDict = myTree[firstStr]
 85     # 减少y偏移
 86     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
 87     for key in secondDict.keys():
 88         if type(secondDict[key]).__name__==dict:#test to see if the nodes are dictonaires, if not they are leaf nodes   
 89             plotTree(secondDict[key],cntrPt,str(key))        #recursion
 90         else:   #it‘s a leaf node print the leaf node
 91             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
 92             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
 93             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
 94     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
 95 
 96 
 97 def retrieveTree(i):
 98     ‘‘‘
 99     保存了树的测试数据
100     ‘‘‘
101     listOfTrees =[{no surfacing: {0: no, 1: {flippers: {0: no, 1: yes}}}},
102                   {no surfacing: {0: no, 1: {flippers: {0: {head: {0: no, 1: yes}}, 1: no}}}}
103                   ]
104     return listOfTrees[i]
105 
106 
107 
108 if __name__== "__main__":  
109     ‘‘‘
110     绘制树
111     ‘‘‘
112     createPlot(retrieveTree(1))

测试结果:

技术分享

4、实例:使用决策树预测隐形眼镜类型

 4.1 处理流程

技术分享

技术分享

 

数据格式如下所示,其中最后一列表示类标签:

技术分享 

4.2 Python实现代码 

1 import trees
2 import treePlotter
3 
4 fr = open(lenses.txt)
5 lenses = [inst.strip().split(\t) for inst in fr.readlines()]
6 lensesLabels=[age,prescript,astigmatic,tearRate]
7 lensesTree = trees.createTree(lenses,lensesLabels)
8 treePlotter.createPlot(lensesTree)

产生的决策树:

技术分享

 

本节使用的算法成为ID3,它是一个号的算法但无法直接处理数值型数据,尽管我们可以通过量化的方法将数值型数据转化为标称型数值,但如果存在太多的特征划分,ID3算法仍然会面临其他问题。

机器学习实战笔记(Python实现)-02-决策树

标签:btree   try   tor   ever   math   完整   data   support   app   

原文地址:http://www.cnblogs.com/hemiy/p/6165759.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!