码迷,mamicode.com

GBDT + LR code

Date: 2018-08-09 15:47:13


import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

np.random.seed(10)

n_estimator = 10
X, y = make_classification(n_samples=80000)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
# Fit the tree ensembles and the LR models on disjoint halves of the
# training data, so the LR does not overfit to leaves it helped create.
X_train, X_train_lr, y_train, y_train_lr = train_test_split(
    X_train, y_train, test_size=0.5)

# Random forest + LR: one-hot encode the leaf index of every tree,
# then train a logistic regression on the encoded leaves.
rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator)
# handle_unknown='ignore' guards against leaf indices that were not
# seen when the encoder was fitted.
rf_enc = OneHotEncoder(handle_unknown='ignore')
rf_lm = LogisticRegression()
rf.fit(X_train, y_train)
rf_enc.fit(rf.apply(X_train))
rf_lm.fit(rf_enc.transform(rf.apply(X_train_lr)), y_train_lr)

y_pred_rf_lm = rf_lm.predict_proba(rf_enc.transform(rf.apply(X_test)))[:, 1]
fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)

# GBDT + LR: same idea with gradient-boosted trees.
grd = GradientBoostingClassifier(n_estimators=n_estimator)
grd_enc = OneHotEncoder(handle_unknown='ignore')
grd_lm = LogisticRegression()
grd.fit(X_train, y_train)  # fit the GBDT model
# apply() returns shape (n_samples, n_estimators, n_classes); for binary
# classification the last axis has size 1, hence the [:, :, 0].
grd_enc.fit(grd.apply(X_train)[:, :, 0])
grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)

y_pred_grd_lm = grd_lm.predict_proba(
    grd_enc.transform(grd.apply(X_test)[:, :, 0]))[:, 1]
fpr_grd_lm, tpr_grd_lm, _ = roc_curve(y_test, y_pred_grd_lm)

# Baselines: the GBDT and the random forest on their own.
y_pred_grd = grd.predict_proba(X_test)[:, 1]
fpr_grd, tpr_grd, _ = roc_curve(y_test, y_pred_grd)

y_pred_rf = rf.predict_proba(X_test)[:, 1]
fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_rf)

# ROC curves, zoomed in on the top-left corner.
plt.figure(2)
plt.xlim(0, 0.2)
plt.ylim(0.8, 1)
plt.plot([0, 1], [0, 1], 'k--')
plt.plot(fpr_rf, tpr_rf, label='RF')
plt.plot(fpr_rf_lm, tpr_rf_lm, label='RF + LR')
plt.plot(fpr_grd, tpr_grd, label='GBT')
plt.plot(fpr_grd_lm, tpr_grd_lm, label='GBT + LR')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('ROC curve (zoomed in at top left)')
plt.legend(loc='best')
plt.show()
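To see what the leaf-encoding step actually produces, here is a minimal sketch on a much smaller dataset. The sample size, `n_estimators=5`, `max_depth=2`, and the variable names `gbdt`, `leaves`, `enc`, and `encoded` are assumptions chosen for illustration only and are not part of the code above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.preprocessing import OneHotEncoder

# Small toy problem (sizes are illustrative assumptions).
X, y = make_classification(n_samples=200, random_state=0)

gbdt = GradientBoostingClassifier(n_estimators=5, max_depth=2, random_state=0)
gbdt.fit(X, y)

# For each sample, the index of the leaf it lands in, one column per tree.
# Binary classification gives a trailing axis of size 1, dropped here.
leaves = gbdt.apply(X)[:, :, 0]
print(leaves.shape)  # (n_samples, n_estimators)

# One-hot encode each tree's leaf index into sparse binary features.
enc = OneHotEncoder(handle_unknown="ignore")
encoded = enc.fit_transform(leaves)
print(encoded.shape)

# Every sample activates exactly one leaf per tree,
# so each encoded row sums to the number of trees.
row_sums = np.asarray(encoded.sum(axis=1)).ravel()
print(row_sums[:3])
```

Each row of `encoded` is the sparse binary feature vector that the logistic regression is trained on in the GBDT + LR setup.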


Original source: http://blog.51cto.com/yixianwei/2156791
