码迷,mamicode.com
首页 > 其他好文 > 详细

Decision Tree Regression

时间:2019-10-07 21:23:31      阅读:67      评论:0      收藏:0      [点我收藏+]

标签:代码   src   show   有一个   cal   strong   sha   ref   amp   

技术图片

Decision Tree Regression

Decision Tree Intuition

技术图片

有两种决策树模型,一种是Classification另一种是Regression,两种决策树原理不同,这次主要学的是后者。

技术图片
说的简单通俗一些,Decision Tree就是将数据切根据特征值切分成不同的区域,如果一个需要预测的数据掉在了某个区域里,就取一个区域中所有数据的平均值作为模型的预测值。

虽然讲起来很简单,但是运算的原理还是很高深的,这里先不深入,等我深入学习了statistical learning modle后再来update。

python

建立模型

现在已经逐渐习惯了python的回归模型建模,都是从sklearn的某个module中导入对应的class然后创建一个regressor instance,最后fit()到训练集上,回归模型就算创建完毕了。

建立Decision Tree需要用到的是sklearn.tree中的DecisionTreeRegressor class。

1
2
3
from sklearn.tree import DecisionTreeRegressor
regressor = DecisionTreeRegressor(random_state = 0)
regressor.fit(X, y)

技术图片

可视化

和之前差不多,绘制Decision Tree模型是先把数据用散点图绘制出来,然后再把预测值画出来加到图上用线段表示作为模型。Decision Tree模型在一维的时候呈现出阶梯的形状。

1
2
3
4
5
6
7
8

X_grid = np.arange(min(X.values), max(X.values), 0.01)
X_grid = X_grid.reshape((len(X_grid), 1))
plt.scatter(X, y, color = 'red')
plt.plot(X_grid, regressor.predict(X_grid), color = 'blue')
plt.title('Truth or Bluff (Decision Tree Regression Model)')
plt.xlabel('Position level')
plt.ylabel('Salary')

技术图片

Decision Tree可视化的问题

我们绘制之前的线性模型的时候只要用下列代码就可以了

1
2
3
4
5
6
7
# Visualising the Decision Tree Regression results
plt.scatter(X, y, color = 'red')
plt.plot(X, regressor.predict(X), color = 'blue')
plt.title('Truth or Bluff (Regression Model)')
plt.xlabel('Position level')
plt.ylabel('Salary')
plt.show()

技术图片

发现得到的图表中Decision Tree是一条连续的倾斜的线。这不符合Decision Tree模型预测值,Decision Tree是在某一个区域内(在这个数据集中,因为特征值只有一个维度,就是在X上取不同的区间)然后取该区域的平均值作为所有点的预测值。但是图上却不是这样,在一个区域内的预测值是不同的。

这个问题的原因所在就是我们只画了1到10的十个点的预测值,不同的点之间并没有点绘制出来,matplotlib只能用线段将两个点连起来,所以就出现了倾斜的线。这种模型与之前的两种模型都不同,既不是Linear也不是Continuous,是Non-Linear & Non-continous模型。

对症下药,解决这个问题的方法就是将更多的点绘制出来,所以新建一个数列,把数据间隔变小,有更多的中间值,这样绘制出来才能看到正确的Decision Tree模型。

原文:大专栏  Decision Tree Regression


Decision Tree Regression

标签:代码   src   show   有一个   cal   strong   sha   ref   amp   

原文地址:https://www.cnblogs.com/wangziqiang123/p/11632139.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!