码迷,mamicode.com
首页 > 其他好文 > 详细

《机器学习实战》——决策树

时间:2016-12-04 23:13:31      阅读:207      评论:0      收藏:0      [点我收藏+]

标签:print   create   ssl   排序   内存   开始   nbsp   flip   取值   

原理(ID3):

依次选定每个特征,计算信息增益(基本信息熵-当前信息熵),选择信息增益最大的一个作为最佳特征;

以该特征作为树的根节点,以该最佳特征的每一个值作为分支,建立子树;

重复上述过程,直到:1) 所有类别一致 2) 特征用尽

优点:

简单容易理解;

可处理有缺失值的特征、非数值型数据;

与训练数据拟合很好,复杂度小

缺点:

选择特征时候需要多次扫描与排序,适合常驻内存的小数据集

容易过拟合

改进:

ID3算法偏向取值较多的特征(宽且浅的树),适合离散型数据,难处理连续型数据。香农熵里面概率的统计根本上是在统计不同分类的概率。

C4.5用信息增益比筛选特征,分母为-∑p*log(p), p为: 特征A某个取值包含的数据行数/总数据行数,可处理缺失值和连续型数据

CARTGini系数筛选特征。同时是二叉决策树,能处理离散特征、连续特征、分类问题、回归问题

Gini系数(杂质度量法):Gini(A)=1-∑(Pk)2. Pk表示观测点中k类的概率,当Gini(A)=0时只有一类,当所有类同概率出现时最大Gini(A)=(C-1)C/2。

如果目标变量是标称的,并且是具有两个以上的类别,则CART可能考虑将目标类别合并成两个超类别(双化);
如果目标变量是连续的,则CART算法找出一组基于树的回归方程来预测目标变量。

事后剪枝,停止条件是:1) 样本个数小于预定阀值 2) 样本的Gini系数小于预定阀值 3)没有更多特征

代码:

 1 #coding: utf-8
 2 from __future__ import division
 3 from numpy import *
 4 
 5 
 6 class myClass(object):
 7     def __init__(self):
 8         group, labels = self.loadData()
 9         myTree = self.createTree(group, labels)
10         print myTree
11 
12     def loadData(self):
13         group = [[1,1,"yes"],
14                 [1,1,"yes"],
15                 [1,0,"no"],
16                 [0,1,"no"],
17                 [0,1,"no"],
18                 ]
19         labels = ["no surfacing", "flippers"]     # 表示特征的含义:海洋生物不露出水面是否可以生存,是否有脚蹼
20         return group, labels
21 
22     def calShannonEnt(self, group):
23         numEntrories = len(group)
24         labelCount = dict()
25         for feaVec in group:
26             currentLabel = feaVec[-1]
27             labelCount[currentLabel] = labelCount.get(currentLabel, 0) + 1
28         shannonEnt = 0.0                            # 用之前定义
29         for key in labelCount.keys():
30             prob = labelCount[key]/numEntrories
31             shannonEnt -= prob * log(prob)
32         return shannonEnt
33 
34     def splitDataSet(self, group, axis, value):     # 特征所有取值的信息熵之和才有意义,这里只计算条件熵。
35         retDataSet = []
36         for feaVec in group:
37             if feaVec[axis] == value:
38                 retDataSet.append(feaVec[:axis] + feaVec[axis+1:])
39         return retDataSet
40 
41     def chooseBestFeature(self, group):
42         numFeatures = len(group[0]) - 1
43         baseEntrory = self.calShannonEnt(group)
44         bestInfoGain = 0.0; bestFeature = -1
45         for i in range(numFeatures):
46             uniqueVals = set([it[i] for it in group])
47             newEntrory = 0.0
48             for value in uniqueVals:
49                 subGroup = self.splitDataSet(group, i, value)
50                 prob = len(subGroup)/len(group)
51                 newEntrory += prob * self.calShannonEnt(subGroup)
52             infoGain = baseEntrory - newEntrory
53             if (infoGain > bestInfoGain):
54                 bestInfoGain = infoGain; bestFeature = i
55         return bestFeature
56 
57     def majority(self, classList):
58         classCountDict = {}
59         for vote in classList:
60             classCountDict[vote] = classCountDict.get(vote, 0) + 1
61         return sorted(classCountDict.items(), key = lambda x:x[1], reverse = True)[0][0]
62 
63     def createTree(self, myGroup, labels):
64         # 1.两个终止条件; 2.建立根树(求得最佳根特征)并从labels中删除根标签(取得根标签); 3.根据根特征的每个值建立子树(取得唯一特征值)
65         classList = [it[-1] for it in myGroup]
66         if classList.count(classList[0]) == len(classList):
67             return classList[0]
68         if len(myGroup[0]) == 1:
69             return self.majorityCnt(classList)
70         rootFeature = self.chooseBestFeature(myGroup)
71         rootLabel = labels[rootFeature]
72         myTree = {rootLabel:{}}
73         del(labels[rootFeature])
74         uniqueVals = set([it[rootFeature] for it in myGroup])
75         for val in uniqueVals:      # 开始递归创建子树
76             subLabels = labels[:]   # 每次都要定义新的subLabels
77             myTree[rootLabel][val] = self.createTree(self.splitDataSet(myGroup, rootFeature, val), subLabels)  # 函数作为参数传入
78         return myTree
79 
80 
81 if __name__ == __main__:
82     A = myClass()

 

《机器学习实战》——决策树

标签:print   create   ssl   排序   内存   开始   nbsp   flip   取值   

原文地址:http://www.cnblogs.com/fresh-bird/p/6132015.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!