码迷,mamicode.com
首页 > 其他好文 > 详细

交叉验证

时间:2018-12-12 20:33:37      阅读:221      评论:0      收藏:0      [点我收藏+]

标签:col   com   查看   迭代   target   利用   amp   参考   atp   

PS:    sklearn.cross_validation已经被移除所以改为从 sklearn.model_selection 中调用train_test_split 等函数

  K层交叉检验就是把原始的数据随机分成K个部分。在这K个部分中,选择一个作为测试数据,剩下的K-1

作为训练数据。 交叉检验的过程实际上是把实验重复做K次,每次实验都从K个部分中选择一个不同的部分

作为测试数据(保证K个部分的数据都分别做过测试数据),剩下的K-1个当作训练数据进行实验,最后把

得到的K个实验结果平均。

  验证模型的精确度,或者某个参数的影响

scores = cross_val_score(knn, X, y, cv=10, scoring=‘accuracy‘)

其中accuracy一般用于分类问题的评估,而均方差一般用于回归问题的评估。

参考代码如下:

import matplotlib.pyplot as plt #可视化模块

#建立测试参数集
k_range = range(1, 31)

k_scores = []

#藉由迭代的方式来计算不同参数对模型的影响,并返回交叉验证后的平均准确率
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=10, scoring=‘accuracy‘)
  #X,y是原始数据集未分割,knn为初始模型未训练
    k_scores.append(scores.mean())

#可视化数据
plt.plot(k_range, k_scores)
plt.xlabel(‘Value of K for KNN‘)
plt.ylabel(‘Cross-Validated Accuracy‘)
plt.show()

 

交叉验证时候查看训练过程中的学习效果,利用Learning curve 检视过拟合

1 train_sizes, train_loss, test_loss = learning_curve(
2     SVC(gamma=0.001), X, y, cv=10, scoring=mean_squared_error,
3     train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
4 
5 #平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)

train_size的参数是指在训练数据的10%,25%时候的误差。

 

 

主要内容:在调节模型的某个参数时候可以用validation_curve来测试Loss值的变化

from sklearn.learning_curve import validation_curve #validation_curve模块
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np

#digits数据集
digits = load_digits()
X = digits.data
y = digits.target

#建立参数测试集
param_range = np.logspace(-6, -2.3, 5)

#使用validation_curve快速找出参数对模型的影响
train_loss, test_loss = validation_curve(
    SVC(), X, y, param_name=‘gamma‘, param_range=param_range, cv=10, scoring=‘mean_squared_error‘)

#平均每一轮的平均方差
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

#可视化图形
plt.plot(param_range, train_loss_mean, ‘o-‘, color="r",
         label="Training")
plt.plot(param_range, test_loss_mean, ‘o-‘, color="g",
        label="Cross-validation")

plt.xlabel("gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

param_range为参数的取值范围以log形式表现, 模型中SVC()不传参数,在param_name中传入要测试的参数名称,在param-range赋予取值范围。

 

交叉验证

标签:col   com   查看   迭代   target   利用   amp   参考   atp   

原文地址:https://www.cnblogs.com/ywheunji/p/10110039.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!