Using sklearn to plot how different values of one parameter perform on the training and test (cross-validation) sets

Date: 2018-10-31 22:37:32
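
Before plotting, it helps to see what `validation_curve` computes: for every candidate value it fits a fresh copy of the estimator on each cross-validation split and records the score on the training folds and on the held-out fold. The sketch below is a minimal illustration, assuming the same synthetic data and `SVC` gamma grid as this post's example (`random_state=0` is an added assumption for reproducibility). It rebuilds both score matrices with `cross_validate(..., return_train_score=True)` and compares them with `validation_curve`. `manual_validation_curve` is a hypothetical helper name, not part of sklearn.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate, validation_curve
from sklearn.svm import SVC


def manual_validation_curve(estimator, X, y, param_name, param_range, cv=5, scoring="accuracy"):
    """Hypothetical helper: rebuild validation_curve's score matrices with an explicit loop."""
    train_rows, test_rows = [], []
    for value in param_range:
        # Fresh, unfitted copy of the estimator with only this one parameter changed
        est = clone(estimator).set_params(**{param_name: value})
        result = cross_validate(est, X, y, cv=cv, scoring=scoring,
                                return_train_score=True)
        train_rows.append(result["train_score"])  # one score per CV fold
        test_rows.append(result["test_score"])    # held-out fold scores
    # Rows: parameter values; columns: CV folds
    return np.array(train_rows), np.array(test_rows)


X, y = make_classification(random_state=0)
param_range = np.logspace(-6, -1, 5)

manual_train, manual_test = manual_validation_curve(SVC(), X, y, "gamma", param_range)
ref_train, ref_test = validation_curve(SVC(), X, y, param_name="gamma",
                                       param_range=param_range, cv=5, scoring="accuracy")

print(manual_train.shape)
print(np.allclose(manual_train, ref_train), np.allclose(manual_test, ref_test))
```

Because an integer `cv` gives the same unshuffled stratified splits in both calls and `SVC` without `probability=True` is deterministic, the two sets of matrices should agree.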

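import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Synthetic binary classification data (100 samples, 20 features by default);
# random_state fixes the data so the curve is reproducible
X, y = make_classification(random_state=0)


def plot_validation_curve(estimator, X, y, param_name="gamma",
                          param_range=np.logspace(-6, -1, 5), cv=5, scoring="accuracy"):
    """
    Plot the mean training score and mean cross-validation score (each with a
    +/- 1 standard deviation band) for every value of one hyperparameter.
    """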
    
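    # For every value in param_range, refit the estimator on each CV split and
    # score it on the training folds and on the held-out fold.
    # Both arrays have shape (len(param_range), number of CV folds).
    train_scores, test_scores = validation_curve(estimator=estimator,
                                                 X=X,
                                                 y=y,
                                                 cv=cv,
                                                 scoring=scoring,
                                                 param_name=param_name,
                                                 param_range=param_range)

    # Summarize across folds: mean score and standard deviation per parameter value
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std  = np.std(train_scores, axis=1)
    test_scores_mean  = np.mean(test_scores, axis=1)
    test_scores_std   = np.std(test_scores, axis=1)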
    
    plt.title("Validation Curve")
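    # Raw string avoids the invalid "\g" escape; label with the parameter name when it is not gamma
    plt.xlabel(r"$\gamma$" if param_name == "gamma" else param_name)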
    plt.ylabel("Score")
    plt.ylim(0.0, 1.1)
    
    plt.semilogx(param_range,train_scores_mean,label="Training score",color="darkorange", lw=2)
    plt.fill_between(param_range,
                     train_scores_mean-train_scores_std,
                     train_scores_mean+train_scores_std,
                     alpha=0.2,
                     color="darkorange", 
                     lw=2)
    
    plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",color="navy", lw=2)    
    plt.fill_between(param_range, 
                     test_scores_mean - test_scores_std,
                     test_scores_mean + test_scores_std, 
                     alpha=0.2,
                     color="navy", 
                     lw=2)
    
    plt.legend(loc="best")
    plt.show()


plot_validation_curve(estimator=SVC(),
                      X=X, y=y,
                      param_name="gamma",
                      param_range=np.logspace(-6, -1, 5), cv=5, scoring="accuracy")
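
The plot is a picture of two score matrices, and reading them as numbers makes it easier to choose a value. The sketch below is a minimal example, assuming the same data, `SVC` and gamma grid as above (`random_state=0` is again an added assumption). It prints the mean training and cross-validation accuracy for each gamma, along with their gap, and picks the gamma with the highest mean cross-validation accuracy. Selecting by the highest mean CV score is a common heuristic, not something the original post prescribes. It prints instead of plotting, so it also runs without a display.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

X, y = make_classification(random_state=0)
param_range = np.logspace(-6, -1, 5)

train_scores, test_scores = validation_curve(
    SVC(), X, y, param_name="gamma", param_range=param_range, cv=5, scoring="accuracy"
)

# Average over CV folds: one number per gamma value
train_mean = train_scores.mean(axis=1)
test_mean = test_scores.mean(axis=1)
# Training minus validation accuracy; a large positive gap is a common sign of overfitting
gap = train_mean - test_mean

for gamma, tr, te, g in zip(param_range, train_mean, test_mean, gap):
    print(f"gamma={gamma:.1e}  train={tr:.3f}  cv={te:.3f}  gap={g:.3f}")

# Heuristic choice: the gamma with the best mean cross-validation accuracy
best_gamma = param_range[np.argmax(test_mean)]
print("best gamma by mean CV accuracy:", best_gamma)
```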
    

 


Original post: https://www.cnblogs.com/wzdLY/p/9886270.html
