码迷,mamicode.com
首页 > 其他好文 > 详细

决策树

时间:2020-04-03 12:24:45      阅读:81      评论:0      收藏:0      [点我收藏+]

标签:sha   保存   dict   pen   float   元素   one   排列   出现   

from math import log
import operator

"""
函数说明:计算给定数据集的经验熵(香农熵)
Parameters:
    dataSet:数据集
Returns:
    shannonEnt:经验熵
Modify:
    2018-03-12

"""
def calcShannonEnt(dataSet):
    #返回数据集行数
    numEntries=len(dataSet)
    #保存每个标签(label)出现次数的字典
    labelCounts={}
    #对每组特征向量进行统计
    for featVec in dataSet:
        currentLabel=featVec[-1]                     #提取标签信息
        if currentLabel not in labelCounts.keys():   #如果标签没有放入统计次数的字典,添加进去
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1                 #label计数

    shannonEnt=0.0                                   #经验熵
    #计算经验熵
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries      #选择该标签的概率
        shannonEnt-=prob*log(prob,2)                 #利用公式计算
    return shannonEnt                                #返回经验熵



"""
函数说明:按照给定特征划分数据集

Parameters:
    dataSet:待划分的数据集
    axis:划分数据集的特征
    value:需要返回的特征值
Returns:
    无
Modify:
    2018-03-13

"""
def splitDataSet(dataSet,axis,value):
    #创建返回的数据集列表
    retDataSet=[]
    #遍历数据集
    for featVec in dataSet:
        if featVec[axis]==value:
            #去掉axis特征
            reduceFeatVec=featVec[:axis]
            #将符合条件的添加到返回的数据集
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    #返回划分后的数据集
    return retDataSet

"""
函数说明:计算给定数据集的经验熵(香农熵)
Parameters:
    dataSet:数据集
Returns:
    shannonEnt:信息增益最大特征的索引值
Modify:
    2018-03-13

"""


def chooseBestFeatureToSplit(dataSet):
    #特征数量
    numFeatures = len(dataSet[0]) - 1
    #计数数据集的香农熵
    baseEntropy = calcShannonEnt(dataSet)
    #信息增益
    bestInfoGain = 0.0
    #最优特征的索引值
    bestFeature = -1
    #遍历所有特征
    for i in range(numFeatures):
        # 获取dataSet的第i个所有特征
        featList = [example[i] for example in dataSet]
        #创建set集合{},元素不可重复
        uniqueVals = set(featList)
        #经验条件熵
        newEntropy = 0.0
        #计算信息增益
        for value in uniqueVals:
            #subDataSet划分后的子集
            subDataSet = splitDataSet(dataSet, i, value)
            #计算子集的概率
            prob = len(subDataSet) / float(len(dataSet))
            #根据公式计算经验条件熵
            newEntropy += prob * calcShannonEnt((subDataSet))
        #信息增益
        infoGain = baseEntropy - newEntropy
        #打印每个特征的信息增益
        print("第%d个特征的增益为%.3f" % (i, infoGain))
        #计算信息增益
        if (infoGain > bestInfoGain):
            #更新信息增益,找到最大的信息增益
            bestInfoGain = infoGain
            #记录信息增益最大的特征的索引值
            bestFeature = i
            #返回信息增益最大特征的索引值
    return bestFeature

"""
函数说明:统计classList中出现次数最多的元素(类标签)
Parameters:
    classList:类标签列表
Returns:
    sortedClassCount[0][0]:出现次数最多的元素(类标签)
Modify:
    2018-03-13

"""
def majorityCnt(classList):
    classCount={}
    #统计classList中每个元素出现的次数
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
            classCount[vote]+=1
        #根据字典的值降序排列
        sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]

"""
函数说明:创建决策树

Parameters:
    dataSet:训练数据集
    labels:分类属性标签
    featLabels:存储选择的最优特征标签
Returns:
    myTree:决策树
Modify:
    2018-03-13

"""
def createTree(dataSet,labels,featLabels):#创建树
    #取分类标签(是否放贷:yes or no)
    classList=[example[-1] for example in dataSet]
    #如果类别完全相同,则停止继续划分
    if classList.count(classList[0])==len(classList):
        return classList[0]
    #是否没有特征
    if len(dataSet[0])==1:
        # 遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    #选择最优特征
    bestFeat=chooseBestFeatureToSplit(dataSet)
    #最优特征的标签
    bestFeatLabel=labels[bestFeat]
    featLabels.append(bestFeatLabel)
    #根据最优特征的标签生成树
    myTree={bestFeatLabel:{}}
    #删除已经使用的特征标签
    del(labels[bestFeat])
    #得到训练集中所有最优特征的属性值
    featValues=[example[bestFeat] for example in dataSet]
    #去掉重复的属性值
    uniqueVls=set(featValues)
    #遍历特征,创建决策树
    for value in uniqueVls:
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),
                                               labels,featLabels)
    return myTree

"""
函数说明:创建测试数据集
Parameters:无
Returns:
    dataSet:数据集
    labels:分类属性
Modify:
    2018-03-13

"""
#用上面的决策树来测试
def classify(inputTree,featLabels,testVec):
    # 得到树中的第一个特征
    global classLabel
    firstStr=list(inputTree.keys())[0]
    # 得到第一个对应的值
    secondDict=inputTree[firstStr]
    # 得到树中第一个特征对应的索引
    # index方法查找当前列表中第一个匹配firstStr变量的元素的索引
    featIndex = featLabels.index(firstStr)
    # 遍历树
    for key in secondDict.keys():
        # 如果在secondDict[key]中找到testVec[featIndex]
        if testVec[featIndex]==key:
            # 判断secondDict[key]是否为字典
            if type(secondDict[key]).__name__==dict:
                # 若为字典,递归的寻找testVec
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                # 若secondDict[key]为标签值,则将secondDict[key]赋给classLabel
                classLabel=secondDict[key]
    # 返回类标签
    return classLabel




def createDataSet():
    # 数据集
    dataSet=[[0, 0, 0, 0, no],
            [0, 0, 0, 1, no],
            [0, 1, 0, 1, yes],
            [0, 1, 1, 0, yes],
            [0, 0, 0, 0, no],
            [1, 0, 0, 0, no],
            [1, 0, 0, 1, no],
            [1, 1, 1, 1, yes],
            [1, 0, 1, 2, yes],
            [1, 0, 1, 2, yes],
            [2, 0, 1, 2, yes],
            [2, 0, 1, 1, yes],
            [2, 1, 0, 1, yes],
            [2, 1, 0, 2, yes],
            [2, 0, 0, 0, no]]
    #分类属性
    labels=[年龄,有工作,有自己的房子,信贷情况]
    #返回数据集和分类属性
    return dataSet,labels
if __name__==__main__:
    dataSet,labels=createDataSet()
    copy_Featlabels=labels[:]
    featLabels=[]
    myTree=createTree(dataSet,labels,featLabels)
    print(myTree)
    while True:
        try:
            test_Feat= input(请输入四个数字,每个数字为0或1,用空号隔开:)
            test_Feat= list(map(int,test_Feat.split( )))
            print(测试数据类别为:+str(classify(myTree,copy_Featlabels,test_Feat)))
        except:
            break

 

决策树

标签:sha   保存   dict   pen   float   元素   one   排列   出现   

原文地址:https://www.cnblogs.com/zlj843767688/p/12625468.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!