标签:
1、引言
决策树是建立在信息论基础之上,对数据进行分类挖掘的一种方法。其思想是,通过一批已知的训练数据建立一棵决策树,然后利用建好的决策树,对数据进行预测。决策树的建立过程可以看成是数据规则的生成过程。由于基于决策树的分类方法结构简单,本身就是人们能够理解的规则。其次,决策树方法计算复杂度不大,分类效率高,能够处理大数据量的训练集;最后,决策树方法的分类精度较高,对噪声数据有较好的健壮性,符合一般系统的要求。说了这么多,可能还不是太了解决策树,用一个例子来说明吧。
套用俗语,决策树分类的思想类似于找对象。现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:
女儿:多大年纪了?
母亲:26。
女儿:长的帅不帅?
母亲:挺帅的。
女儿:收入高不?
母亲:不算很高,中等情况。
女儿:是公务员不?
母亲:是,在税务局上班呢。
女儿:那好,我去见见。
这个女孩的决策过程就是典型的分类树决策。相当于通过年龄、长相、收入和是否公务员对将男人分为两个类别:见和不见。假设这个女孩对男人的要求是:30岁以下、长相中等以上并且是高收入者或中等以上收入的公务员,那么这个可以用下图表示女孩的决策逻辑:
也就是说,对未知的选项都可以归类到已知的选项分类类别中。
2、决策树描述
决策树,又称为判定树,是一种类似二叉树或多叉树的树结构。树中的每个非叶节点(包括根节点)对应于训练样本集中一个非类别属性的测试,非叶节点的每个分支对应属性的一个测试结果,每个叶子节点则代表一个类或类分布。从根节点到叶子节点的一条路径形成一条分类规则。决策树可以很方便地转化为分类规则,是一种非常直观的分类模式表示形式。决策树方法的起源是概念学习系统CLS,然后发展到ID3方法而为高潮,最后演化为能处理连续属性的C4.5。有名的决策树方法还有CART和Assistant。是应用最广的归纳推理算法之一。
决策树学习是一种归纳学习方法,当前国际上最有影响的示例学习方法首推的应当是R.Quinlan提出的ID3算法,其前身是概念学习系统CLS。ID3算法是所有可能决策树空间中一种自顶向下、贪婪的搜索方法,以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。
在决策树构造中,如何选取一个条件属性作为形成决策树的节点是建树的核心。一般情况下,选取的属性能最大程度反映训练样本集的分类特征。ID3算法作为决策构造中的经典算法,引入了信息论的方法,应用信息论中的熵的概念,采用信息增益作为选择属性的标准来对训练样本集进行划分,选取信息增益最大的属性作为当前节点。计算信息增益还要涉及三个概念:信息熵、信息增益和信息条件熵。
信息熵
信息熵也称为香农熵,是随机变量的期望。度量信息的不确定程度。信息的熵越大,信息就越不容易搞清楚。处理信息就是为了把信息搞清楚,就是熵减少的过程。
计算香农熵的Python代码为:
#计算给定数据集的香农熵 def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries shannonEnt -= prob*log(prob,2) return shannonEnt
信息条件熵
(以上公式为属性A的信息条件熵)
信息增益
用于度量属性A降低样本集合X熵的贡献大小。信息增益越大,越适于对X分类。
3、ID3算法
ID3算法是决策树算法的一种。想了解什么是ID3算法之前,我们得先明白一个概念:奥卡姆剃刀。
ID3算法(Iterative Dichotomiser 3 迭代二叉树3代)是一个由Ross Quinlan发明的用于决策树的算法。这个算法便是建立在上述所介绍的奥卡姆剃刀的基础上:越是小型的决策树越优于大的决策树(be simple简单理论)。尽管如此,该算法也不是总是生成最小的树形结构,而是一个启发式算法。
OK,从信息论知识中我们知道,期望信息越小,信息增益越大,从而纯度越高。ID3算法的核心思想就是以信息增益度量属性选择,选择分裂后信息增益(很快,由下文你就会知道信息增益又是怎么一回事)最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策树空间。
所以,ID3的思想便是:
这形成了对合格决策树的贪婪搜索,也就是算法从不回溯重新考虑以前的选择。
寻找最佳属性的Python代码:
#选择最好的数据集划分方式 def chooseBestFeatureToSplit(dataSet): numFeatures = len(dataSet[0])-1 baseEntropy = calcShannonEnt(dataSet) #计算香农熵 bestInfoGain = 0.0;bestFeature = -1 for i in range(numFeatures): featList = [example[i] for example in dataSet] uniqueVals = set(featList) newEntroy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prop = len(subDataSet)/float(len(dataSet)) newEntroy += prop * calcShannonEnt(subDataSet) #计算条件信息熵 infoGain = baseEntropy – newEntroy #信息增益 if(infoGain > bestInfoGain): bestInfoGain = infoGain bestFeature = i return bestFeature
我们通过一个具体的例子来讲解ID3算法。主要通过两个代码文件实现。ID3决策树算法的相关操作放在文件trees.py中
# -*- coding: utf-8 -*- ‘‘‘ Created on 2015年7月27日 @author: pcithhb ‘‘‘ from math import log import operator #计算给定数据集的香农熵 def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries shannonEnt -= prob*log(prob,2) return shannonEnt #按照给定特征划分数据集 #dataSet:待划分的数据集 #axis:划分数据集的特征--数据的第几列 #value:需要返回的特征值 def splitDataSet(dataSet,axis,value): retDataSet = [] for featVec in dataSet: if featVec[axis] == value: reducedFeatVec = featVec[:axis] #获取从第0列到特征列的数据 reducedFeatVec.extend(featVec[axis+1:]) #获取从特征列之后的数据 retDataSet.append(reducedFeatVec) return retDataSet #选择最好的数据集划分方式 def chooseBestFeatureToSplit(dataSet): numFeatures = len(dataSet[0])-1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0;bestFeature = -1 for i in range(numFeatures): featList = [example[i] for example in dataSet] uniqueVals = set(featList) newEntroy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prop = len(subDataSet)/float(len(dataSet)) newEntroy += prop * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntroy if(infoGain > bestInfoGain): bestInfoGain = infoGain bestFeature = i return bestFeature # def majorityCnt(classList): classCount = {} for vote in classList: if vote not in classCount.keys():classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classList.iteritems(),key=operator.itemgetter(1),reverse=True)#利用operator操作键值排序字典 return sortedClassCount[0][0] #创建树的函数 def createTree(dataSet,labels): classList = [example[-1] for example in dataSet] if classList.count(classList[0]) == len(classList): return classList[0] if len(dataSet[0]) == 1: return majorityCnt(classList) bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel:{}} del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) return myTree #创建数据集 def createDataSetFromTXT(filename): dataSet = []; labels = [] fr = open(filename) linenumber=0 for line in fr.readlines(): line = line.strip() listFromLine = line.strip().split() lineset = [] for cel in listFromLine: lineset.append(cel) if(linenumber==0): labels=lineset else: dataSet.append(lineset) linenumber = linenumber+1 return dataSet,labels
决策树计算图形化相关操作放在treePlotter.py文件中
# -*- coding: utf-8 -*- ‘‘‘ Created on 2015年7月27日 @author: pcithhb ‘‘‘ import matplotlib.pyplot as plt decisionNode = dict(boxstyle="sawtooth", fc="0.8") leafNode = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") #获取叶节点的数目 def getNumLeafs(myTree): numLeafs = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘:#测试节点的数据是否为字典,以此判断是否为叶节点 numLeafs += getNumLeafs(secondDict[key]) else: numLeafs +=1 return numLeafs #获取树的层数 def getTreeDepth(myTree): maxDepth = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘:#测试节点的数据是否为字典,以此判断是否为叶节点 thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth #绘制节点 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=‘axes fraction‘, xytext=centerPt, textcoords=‘axes fraction‘, va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) #绘制连接线 def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) #绘制树结构 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on numLeafs = getNumLeafs(myTree) #this determines the x width of this tree depth = getTreeDepth(myTree) firstStr = myTree.keys()[0] #the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘:#test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key],cntrPt,str(key)) #recursion else: #it‘s a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #创建决策树图形 def createPlot(inTree): fig = plt.figure(1, facecolor=‘white‘) fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), ‘‘) plt.show()
假设在核电站运行过程中,判断某类事故发生与不发生,而事故时由各种部件的故障特征表现出来的,为了简单,我们假设训练集包含10条元素(存放于文件dataset.txt)。其中T1-T6表示各部件的值,Y1表示某种事故,1-表示发生,0-表示不发生
我们先通过从文件加载到数据集中,然后计算数据集的信息熵
# -*- coding: utf-8 -*- ‘‘‘ Created on 2015年7月27日 @author: pcithhb ‘‘‘ import trees import treePlotter if __name__ == ‘__main__‘: pass myDat,labels = trees.createDataSetFromTXT("dataset.txt") shan = trees.calcShannonEnt(myDat) print shan
结果为:0.881290899231
然后通过计算信息增益,得到第一次最佳的分割属性:
col = trees.chooseBestFeatureToSplit(myDat) print col
结果为:4,意味着最佳的分割属性为T5.
最后通过构建决策树,
Tree = trees.createTree(myDat, labels) print Tree treePlotter.createPlot(Tree)
结果为:{‘T5‘: {‘0.25‘: {‘T1‘: {‘0.5‘: ‘0‘, ‘0.75‘: ‘1‘}}, ‘0.5‘: ‘0‘, ‘0.75‘: ‘0‘}}
图形化的结果图为:
标签:
原文地址:http://www.cnblogs.com/hantan2008/p/4674097.html