码迷,mamicode.com
首页 > 其他好文 > 详细

决策树代码(转)

时间:2015-09-29 22:03:32      阅读:209      评论:0      收藏:0      [点我收藏+]

标签:

转自http://blog.csdn.net/fy2462/article/details/31762429

一、前言

       当年实习公司布置了一个任务让写一个决策树,以前并未接触数据挖掘的东西,但作为一个数据挖掘最基本的知识点,还是应该有所理解的。

  程序的源码可以点击这里进行下载,下面简要介绍一下决策树以及相关算法概念。

  决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。 数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测(就像上面的银行官员用他来预测贷款风险)。从数据产生决策树的机器学习技术叫做决策树学习, 通俗说就是决策树。(来自维基百科)

  1986年Quinlan提出了著名的ID3算法。在ID3算法的基础上,1993年Quinlan又提出了C4.5算法。为了适应处理大规模数据集的需要,后来又提出了若干改进的算法,其中SLIQ (super-vised learning in quest)和SPRINT (scalable parallelizableinduction of decision trees)是比较有代表性的两个算法,此处暂且略过。

  本文实现了C4.5的算法,在ID3的基础上计算信息增益,从而更加准确的反应信息量。其实通俗的说就是构建一棵加权的最短路径Haffman树,让权值最大的节点为父节点。

二、基本概念

  下面简要介绍一下ID3算法:

  ID3算法的核心是:在决策树各级结点上选择属性时,用信息增益(information gain)作为属性的选择标准,以使得在每一个非叶结点进行测试时,能获得关于被测试记录最大的类别信息。

  其具体方法是:检测所有的属性,选择信息增益最大的属性产生决策树结点,由该属性的不同取值建立分支,再对各分支的子集递归调用该方法建立决策树结点的分支,直到所有子集仅包含同一类别的数据为止。最后得到一棵决策树,它可以用来对新的样本进行分类。

  某属性的信息增益按下列方法计算:

技术分享

 

      信息熵是香农提出的,用于描述信息不纯度(不稳定性),其计算公式是Info(D)。

  其中:Pi为子集合中不同性(而二元分类即正样例和负样例)的样例的比例;j是属性A中的索引,D是集合样本,Dj是D中属性A上值等于j的样本集合。

      这样信息收益可以定义为样本按照某属性划分时造成熵减少的期望,可以区分训练样本中正负样本的能力。信息增益定义为结点与其子结点的信息熵之差,公式为Gain(A)。

  ID3算法的优点是:算法的理论清晰,方法简单,学习能力较强。其缺点是:只对比较小的数据集有效,且对噪声比较敏感,当训练数据集加大时,决策树可能会随之改变。

  C4.5算法继承了ID3算法的优点,并在以下几方面对ID3算法进行了改进:

  1) 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足,公式为GainRatio(A);

  2) 在树构造过程中进行剪枝;

  3) 能够完成对连续属性的离散化处理;

  4) 能够对不完整数据进行处理。

  C4.5算法与其它分类算法如统计方法、神经网络等比较起来有如下优点:产生的分类规则易于理解,准确率较高。其缺点是:在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效。此外,C4.5只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。

技术分享

三、数据集

实现的C4.5数据集合如下:

技术分享

 

它记录了再不同的天气状况下,是否出去觅食的数据。

四、程序代码

  程序引入状态树作为统计和计算属性的数据结构,它记录了每次计算后,各个属性的统计数据,其定义如下:

 

  1. struct attrItem  
  2. {  
  3.    std::vector<int>  itemNum;  //itemNum[0] = itemLine.size()  
  4.                                //itemNum[1] = decision num  
  5.    set<int>          itemLine;  
  6. };  
  7.   
  8. struct attributes  
  9. {  
  10.    string attriName;  
  11.    vector<double> statResult;  
  12.    map<string, attrItem*> attriItem;  
  13. };   
  14.   
  15. vector<attributes*> statTree;  


决策树节点数据结构如下:

 

  1. struct TreeNode   
  2. {  
  3.     std::string               m_sAttribute;  
  4.     int                       m_iDeciNum;  
  5.     int                       m_iUnDecinum;  
  6.     std::vector<TreeNode*>    m_vChildren;      
  7. };  


程序源码如下所示(程序中有详细注解):

 

  1. #include "DecisionTree.h"  
  2.   
  3. int main(int argc, char* argv[]){  
  4.     string filename = "source.txt";  
  5.     DecisionTree dt ;  
  6.     int attr_node = 0;  
  7.     TreeNode* treeHead = nullptr;  
  8.     set<int> readLineNum;  
  9.     vector<int> readClumNum;  
  10.     int deep = 0;  
  11.     if (dt.pretreatment(filename, readLineNum, readClumNum) == 0)  
  12.     {  
  13.         dt.CreatTree(treeHead, dt.getStatTree(), dt.getInfos(), readLineNum, readClumNum, deep);  
  14.     }  
  15.     return 0;  
  16. }  
  17. /* 
  18. * @function CreatTree 预处理函数,负责读入数据,并生成信息矩阵和属性标记 
  19. * @param: filename 文件名 
  20. * @param: readLineNum 可使用行set 
  21. * @param: readClumNum 可用属性set 
  22. * @return int 返回函数执行状态 
  23. */  
  24. int DecisionTree::pretreatment(string filename, set<int>& readLineNum, vector<int>& readClumNum)  
  25. {  
  26.     ifstream read(filename.c_str());  
  27.     string itemline = "";  
  28.     getline(read, itemline);  
  29.     istringstream iss(itemline);  
  30.     string attr = "";  
  31.     while(iss >> attr)  
  32.     {  
  33.         attributes* s_attr = new attributes();  
  34.         s_attr->attriName = attr;  
  35.         //初始化属性名  
  36.         statTree.push_back(s_attr);  
  37.         //初始化属性映射  
  38.         attr_clum[attr] = attriNum;  
  39.         attriNum++;  
  40.         //初始化可用属性列  
  41.         readClumNum.push_back(0);  
  42.         s_attr = nullptr;  
  43.     }  
  44.   
  45.     int i  = 0;  
  46.     //添加具体数据  
  47.     while(true)  
  48.     {  
  49.         getline(read, itemline);  
  50.         if(itemline == "" || itemline.length() <= 1)  
  51.         {  
  52.             break;  
  53.         }  
  54.         vector<string> infoline;  
  55.         istringstream stream(itemline);  
  56.         string item = "";  
  57.         while(stream >> item)  
  58.         {  
  59.             infoline.push_back(item);  
  60.         }  
  61.   
  62.         infos.push_back(infoline);  
  63.         readLineNum.insert(i);  
  64.         i++;  
  65.     }  
  66.     read.close();  
  67.     return 0;  
  68. }  
  69.   
  70. int DecisionTree::statister(vector<vector<string>>& infos, vector<attributes*>& statTree,   
  71.                             set<int>& readLine, vector<int>& readClumNum)  
  72. {  
  73.     //yes的总行数  
  74.     int deciNum = 0;  
  75.     //统计每一行  
  76.     set<int>::iterator iter_end = readLine.end();  
  77.     for (set<int>::iterator line_iter = readLine.begin(); line_iter != iter_end; ++line_iter)  
  78.     {  
  79.         bool decisLine = false;  
  80.         if (infos[*line_iter][attriNum - 1] == "yes")  
  81.         {  
  82.             decisLine = true;  
  83.             deciNum++;   
  84.         }  
  85.         //如果该列未被锁定并且为属性列,进行统计  
  86.         for (int i = 0; i < attriNum - 1; i++)  
  87.         {  
  88.             if (readClumNum[i] == 0)  
  89.             {  
  90.                 std::string tempitem = infos[*line_iter][i];  
  91.                 auto map_iter = statTree[i]->attriItem.find(tempitem);  
  92.                 //没有找到  
  93.                 if (map_iter == (statTree[i]->attriItem).end())  
  94.                 {  
  95.                     //新建  
  96.                     attrItem* attritem = new attrItem();  
  97.                     attritem->itemNum.push_back(1);  
  98.                     decisLine ? attritem->itemNum.push_back(1) : attritem->itemNum.push_back(0);  
  99.                     attritem->itemLine.insert(*line_iter);  
  100.                     //建立属性名->item映射  
  101.                     (statTree[i]->attriItem)[tempitem] = attritem;  
  102.                     attritem = nullptr;  
  103.                 }  
  104.                 else  
  105.                 {  
  106.                     (map_iter->second)->itemNum[0]++;  
  107.                     (map_iter->second)->itemLine.insert(*line_iter);  
  108.                     if(decisLine)  
  109.                     {  
  110.                         (map_iter->second)->itemNum[1]++;  
  111.                     }  
  112.                 }  
  113.             }  
  114.         }  
  115.     }  
  116.     return deciNum;  
  117. }  
  118.   
  119. /* 
  120. * @function CreatTree 递归DFS创建并输出决策树 
  121. * @param: treeHead 为生成的决定树 
  122. * @param: statTree 为状态树,此树动态更新,但是由于是DFS对数据更新,所以不必每次新建状态树 
  123. * @param: infos 数据信息 
  124. * @param: readLine 当前在infos中所要进行统计的行数,由函数外给出 
  125. * @param: deep 决定树的深度,用于打印 
  126. * @return void 
  127. */  
  128. void DecisionTree::CreatTree(TreeNode* treeHead, vector<attributes*>& statTree, vector<vector<string>>& infos,   
  129.                              set<int>& readLine, vector<int>& readClumNum, int deep)  
  130. {  
  131.     //有可统计的行  
  132.     if (readLine.size() != 0)  
  133.     {  
  134.         string treeLine = "";  
  135.         for (int i = 0; i < deep; i++)  
  136.         {  
  137.             treeLine += "--";  
  138.         }  
  139.         //清空其他属性子树,进行递归  
  140.         resetStatTree(statTree, readClumNum);  
  141.         //统计当前readLine中的数据:包括统计哪几个属性、哪些行,  
  142.         //并生成statTree(由于公用一个statTree,所有用引用代替),并返回目的信息数  
  143.         int deciNum = statister(getInfos(), statTree, readLine, readClumNum);  
  144.         int lineNum = readLine.size();  
  145.         int attr_node = compuDecisiNote(statTree, deciNum, lineNum, readClumNum);//本条复制为局部变量  
  146.         //该列被锁定  
  147.         readClumNum[attr_node] = 1;  
  148.         //建立树根  
  149.         TreeNode* treeNote = new TreeNode();  
  150.         treeNote->m_sAttribute = statTree[attr_node]->attriName;  
  151.         treeNote->m_iDeciNum = deciNum;  
  152.         treeNote->m_iUnDecinum = lineNum - deciNum;  
  153.         if (treeHead == nullptr)  
  154.         {  
  155.             treeHead = treeNote; //树根  
  156.         }  
  157.         else  
  158.         {  
  159.             treeHead->m_vChildren.push_back(treeNote); //子节点  
  160.         }  
  161.         cout << "节点-"<< treeLine << ">" << statTree[attr_node]->attriName    << endl;  
  162.           
  163.         //从孩子分支进行递归  
  164.         for(map<string, attrItem*>::iterator map_iterator = statTree[attr_node]->attriItem.begin();  
  165.             map_iterator != statTree[attr_node]->attriItem.end(); ++map_iterator)  
  166.         {  
  167.             //打印分支  
  168.             int sum = map_iterator->second->itemNum[0];  
  169.             int deci_Num = map_iterator->second->itemNum[1];  
  170.             cout << "分支--"<< treeLine << ">" << map_iterator->first << endl;  
  171.             //递归计算、创建  
  172.             if (deci_Num != 0 && sum != deci_Num )  
  173.             {  
  174.                 //计算有效行数  
  175.                 set<int> newReadLineNum = map_iterator->second->itemLine;  
  176.                 //DFS  
  177.                 CreatTree(treeNote, statTree, infos, newReadLineNum, readClumNum, deep + 1);  
  178.             }  
  179.             else  
  180.             {  
  181.                 //建立叶子节点  
  182.                 TreeNode* treeEnd = new TreeNode();  
  183.                 treeEnd->m_sAttribute = statTree[attr_node]->attriName;  
  184.                 treeEnd->m_iDeciNum = deci_Num;  
  185.                 treeEnd->m_iUnDecinum = sum - deci_Num;  
  186.                 treeNote->m_vChildren.push_back(treeEnd);  
  187.                 //打印叶子  
  188.                 if (deci_Num == 0)  
  189.                 {  
  190.                     cout << "叶子---"<< treeLine << ">no" << endl;  
  191.                 }  
  192.                 else  
  193.                 {  
  194.                     cout << "叶子---"<< treeLine << ">yes" << endl;  
  195.                 }  
  196.             }  
  197.         }  
  198.         //还原属性列可用性  
  199.         readClumNum[attr_node] = 0;  
  200.     }  
  201. }  
  202. /* 
  203. * @function compuDecisiNote 计算C4.5 
  204. * @param: statTree 为状态树,此树动态更新,但是由于是DFS对数据更新,所以不必每次新建状态树 
  205. * @param: deciNum Yes的数据量 
  206. * @param: lineNum 计算set的行数 
  207. * @param: readClumNum 用于计算的set 
  208. * @return int 信息量最大的属性号 
  209. */  
  210. int DecisionTree::compuDecisiNote(vector<attributes*>& statTree, int deciNum, int lineNum, vector<int>& readClumNum)  
  211. {  
  212.     double max_temp = 0;  
  213.     int max_attribute = 0;  
  214.     //总的yes行的信息量  
  215.     double infoD = info_D(deciNum, lineNum);  
  216.     for (int i = 0; i < attriNum - 1; i++)  
  217.     {  
  218.         if (readClumNum[i] == 0)  
  219.         {  
  220.             double splitInfo = 0.0;  
  221.             //info  
  222.             double info_temp = Info_attr(statTree[i]->attriItem, splitInfo, lineNum);  
  223.             statTree[i]->statResult.push_back(info_temp);  
  224.             //gain  
  225.             double gain_temp = infoD - info_temp;  
  226.             statTree[i]->statResult.push_back(gain_temp);  
  227.             //split_info  
  228.             statTree[i]->statResult.push_back(splitInfo);  
  229.             //gain_info  
  230.             double temp = gain_temp / splitInfo;  
  231.             statTree[i]->statResult.push_back(temp);  
  232.             //得到最大值*/  
  233.             if (temp > max_temp)  
  234.             {  
  235.                 max_temp = temp;  
  236.                 max_attribute = i;  
  237.             }  
  238.         }  
  239.     }  
  240.     return max_attribute;  
  241. }  
  242. /* 
  243. * @function Info_attr info_D 总信息量 
  244. * @param: deciNum 有效信息数 
  245. * @param: sum 总信息量 
  246. * @return double 总信息量比例 
  247. */  
  248. double DecisionTree::info_D(int deciNum, int sum)  
  249. {  
  250.     double pi = (double)deciNum / (double)sum;  
  251.     double result = 0.0;  
  252.     if (pi == 1.0 || pi == 0.0)  
  253.     {  
  254.         return result;  
  255.     }  
  256.     result = pi * (log(pi) / log((double)2)) + (1 - pi)*(log(1 - pi)/log((double)2));  
  257.     return -result;  
  258. }  
  259. /* 
  260. * @function Info_attr 总信息量 
  261. * @param: deciNum 有效信息数 
  262. * @param: sum 总信息量 
  263. * @return double  
  264. */  
  265. double DecisionTree::Info_attr(map<string, attrItem*>& attriItem, double& splitInfo, int lineNum)  
  266. {  
  267.     double result = 0.0;  
  268.     for (map<string, attrItem*>::iterator item = attriItem.begin();  
  269.          item != attriItem.end();  
  270.          ++item  
  271.         )  
  272.     {  
  273.          double pi = (double)(item->second->itemNum[0]) / (double)lineNum;  
  274.          splitInfo += pi * (log(pi) / log((double)2));  
  275.          double sub_attr = info_D(item->second->itemNum[1], item->second->itemNum[0]);  
  276.          result += pi * sub_attr;  
  277.     }  
  278.     splitInfo = -splitInfo;  
  279.     return result;  
  280. }  
  281. /* 
  282. * @function resetStatTree 清理状态树 
  283. * @param: statTree 状态树 
  284. * @param: readClumNum 需要清理的属性set 
  285. * @return void 
  286. */  
  287. void DecisionTree::resetStatTree(vector<attributes*>& statTree, vector<int>& readClumNum)  
  288. {  
  289.     for (int i = 0; i < readClumNum.size() - 1; i++)  
  290.     {  
  291.         if (readClumNum[i] == 0)  
  292.         {  
  293.             map<string, attrItem*>::iterator it_end = statTree[i]->attriItem.end();  
  294.             for (map<string, attrItem*>::iterator it = statTree[i]->attriItem.begin();  
  295.                 it != it_end; it++)  
  296.             {  
  297.                 delete it->second;  
  298.             }  
  299.             statTree[i]->attriItem.clear();  
  300.             statTree[i]->statResult.clear();  
  301.         }  
  302.     }  
  303. }  

 

五、结果分析

程序输出结果为:

 

技术分享

以图形表示为:

技术分享

 

六、小结:

 

  1、在设计程序时,对程序逻辑有时会发生混乱,·后者在纸上仔细画了些草图才解决这些问题,画一个好图可以有效的帮助你理解程序的流程以及逻辑脉络,是需求分析时最为关键的基本功。

  2、在编写程序之初,一直在纠结用什么样的数据结构,后来经过几次在编程实现推敲,才确定最佳的数据结构,可见数据结构在程序中的重要性。

  3、决策树的编写,其实就是理论与实践的相结合,虽然理论上比较简单,但是实践中却会遇到这样那样的问题,而这些问题就是考验一个程序员对最基本的数据结构、算法的理解和熟练程度,所以,勤学勤练基本功依然是关键。

  4、程序的效率还有待提高,欢迎各路高手指正。

决策树代码(转)

标签:

原文地址:http://www.cnblogs.com/yanjunhelloworld/p/4847231.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!