标签:代码 src show 有一个 cal strong sha ref amp
有两种决策树模型,一种是Classification另一种是Regression,两种决策树原理不同,这次主要学的是后者。
说的简单通俗一些,Decision Tree就是将数据切根据特征值切分成不同的区域,如果一个需要预测的数据掉在了某个区域里,就取一个区域中所有数据的平均值作为模型的预测值。
虽然讲起来很简单,但是运算的原理还是很高深的,这里先不深入,等我深入学习了statistical learning modle后再来update。
现在已经逐渐习惯了python的回归模型建模,都是从sklearn
的某个module中导入对应的class
然后创建一个regressor
instance,最后fit()
到训练集上,回归模型就算创建完毕了。
建立Decision Tree需要用到的是sklearn.tree
中的DecisionTreeRegressor
class。
1 | from sklearn.tree import DecisionTreeRegressor |
和之前差不多,绘制Decision Tree模型是先把数据用散点图绘制出来,然后再把预测值画出来加到图上用线段表示作为模型。Decision Tree模型在一维的时候呈现出阶梯的形状。
1 |
|
我们绘制之前的线性模型的时候只要用下列代码就可以了
1 | # Visualising the Decision Tree Regression results |
发现得到的图表中Decision Tree是一条连续的倾斜的线。这不符合Decision Tree模型预测值,Decision Tree是在某一个区域内(在这个数据集中,因为特征值只有一个维度,就是在X上取不同的区间)然后取该区域的平均值作为所有点的预测值。但是图上却不是这样,在一个区域内的预测值是不同的。
这个问题的原因所在就是我们只画了1到10的十个点的预测值,不同的点之间并没有点绘制出来,matplotlib
只能用线段将两个点连起来,所以就出现了倾斜的线。这种模型与之前的两种模型都不同,既不是Linear也不是Continuous,是Non-Linear & Non-continous模型。
对症下药,解决这个问题的方法就是将更多的点绘制出来,所以新建一个数列,把数据间隔变小,有更多的中间值,这样绘制出来才能看到正确的Decision Tree模型。
原文:大专栏 Decision Tree Regression
标签:代码 src show 有一个 cal strong sha ref amp
原文地址:https://www.cnblogs.com/wangziqiang123/p/11632139.html