标签:square col ram dataset dig random height img code
还是图片显示~
import pandas as pd import numpy as np import matplotlib.pyplot as plt import matplotlib as mpl import seaborn as sns mpl.rcParams[‘font.sans-serif‘] = [u‘SimHei‘] mpl.rcParams[‘axes.unicode_minus‘] = False from sklearn.model_selection import train_test_split from sklearn.metrics import mean_squared_error #制造一些数据 x_data = np.random.rand(200,5)*10 w = np.array([2,4,6,8,10]) y = np.dot(x_data,w) + np.random.rand(200)*20 + 20 x_train,x_test,y_train,y_test = train_test_split(x_data,y) #先试试决策回归树吧,都没试过 from sklearn.tree import DecisionTreeRegressor dtr = DecisionTreeRegressor() dtr.fit(x_train,y_train) y_hat = dtr.predict(x_test) print(‘决策回归树的误差:‘,mean_squared_error(y_test,y_hat)) #画图 fig,ax = plt.subplots() fig.set_size_inches(10,6) ax.plot(np.arange(len(y_hat)),y_hat,color = ‘r‘) ax.plot(np.arange(len(y_hat)),y_test,color = ‘g‘)
决策树做回归也太差了吧,难道是我调参有问题吗?一会试试调参看看
决策回归树的误差: 667.87208618
#调参一下试试 from sklearn.model_selection import GridSearchCV model_dtr = GridSearchCV(dtr,param_grid=({‘max_depth‘:np.arange(1,50)}),cv=10) model_dtr.fit(x_train,y_train) y_hat = model_dtr.predict(x_test) print(model_dtr.best_params_) print(‘决策回归树的误差:‘,mean_squared_error(y_test,y_hat))
还是没啥用,好差的效果,同样的数据,前面线性回归的均方误差才二十几
决策回归树的误差: 643.989924585
from sklearn.datasets import load_digits from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score digits = load_digits() x_data = digits.data y_data = digits.target x_train,x_test,y_train,y_test = train_test_split(x_data,y_data) dtc = DecisionTreeClassifier() dtc.fit(x_train,y_train) y_hat = dtc.predict(x_test) print(‘正确率‘,accuracy_score(y_hat,y_test)) #调个参数看看 model_dtc = GridSearchCV(dtc,param_grid=({‘max_depth‘:np.arange(1,20)}),cv=10) model_dtc.fit(x_train,y_train) print(‘最佳参数‘,model_dtc.best_params_) y_hat = model_dtc.predict(x_test) y_hat = dtc.predict(x_test) print(‘正确率‘,accuracy_score(y_hat,y_test))
正确率 0.817777777778
最佳参数 {‘max_depth‘: 18}
正确率 0.817777777778
决策树还是弱分类器啊,难怪都喜欢用它来做ensemble
标签:square col ram dataset dig random height img code
原文地址:https://www.cnblogs.com/slowlyslowly/p/8810958.html