标签:验证 div ace alt logs ati svm bubuko ring
可以使用 cross_val_score 做交叉验证，也可以使用 learning_curve 和 validation_curve。
# Choose k for a k-nearest-neighbours classifier on the iris dataset by
# comparing the mean 10-fold cross-validated accuracy for k = 1..30.
# (`%matplotlib inline` is IPython magic, not Python; run it in a notebook
# cell if inline figures are wanted.)
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score  # sklearn.cross_validation was removed in 0.20
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
x_data = iris.data
y_data = iris.target

k_range = range(1, 31)
k_score = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    # One accuracy per fold; average them into a single estimate for this k.
    score = cross_val_score(knn, x_data, y_data, cv=10, scoring='accuracy')
    k_score.append(score.mean())

plt.figure()
plt.plot(k_range, k_score)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-validated accuracy')
plt.show()
# Learning curve: 10-fold cross-validated accuracy of an SVC (gamma=0.001) on
# the digits dataset as the training set grows, plotted as a
# misclassification loss (1 - accuracy) for the training and validation folds.
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve  # sklearn.learning_curve was removed in 0.20
from sklearn.svm import SVC

digits = load_digits()
x_data = digits.data
y_data = digits.target

# Float train_sizes are fractions of the maximum training-set size.
train_size, train_score, test_score = learning_curve(
    SVC(gamma=0.001), x_data, y_data, cv=10, scoring='accuracy',
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1.0])

# Scores have one row per training size and one column per CV fold.
# With scoring='accuracy' higher is better, so the loss is 1 - accuracy;
# negating the scores (appropriate only for sklearn's neg_* scorers) would
# plot negative accuracy instead.
train_loss_mean = 1 - train_score.mean(axis=1)
test_loss_mean = 1 - test_score.mean(axis=1)

plt.figure()  # fresh figure rather than drawing onto a previous one
plt.plot(train_size, train_loss_mean, 'r-o', label='train_loss')
plt.plot(train_size, test_loss_mean, 'g-o', label='test_loss')
plt.xlabel('Training examples')
plt.ylabel('Loss (1 - accuracy)')
plt.legend()
plt.show()
# Validation curve: 10-fold cross-validated accuracy of an SVC on the digits
# dataset for gamma in {1e-6, 1e-5, ..., 1e-2}, plotted as a
# misclassification loss (1 - accuracy) for the training and validation folds.
import matplotlib.pyplot as plt
import numpy as np  # used below; the original snippet never imported it
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve  # sklearn.learning_curve was removed in 0.20
from sklearn.svm import SVC

digits = load_digits()
x_data = digits.data
y_data = digits.target

param_range = np.logspace(-6, -2, 5)  # 1e-6, 1e-5, 1e-4, 1e-3, 1e-2
train_score, test_score = validation_curve(
    SVC(), x_data, y_data, param_name='gamma', param_range=param_range,
    cv=10, scoring='accuracy')

# With scoring='accuracy' higher is better, so the loss is 1 - accuracy
# (negation is only appropriate for sklearn's neg_* scorers).
train_loss_mean = 1 - train_score.mean(axis=1)
test_loss_mean = 1 - test_score.mean(axis=1)

plt.figure()
# gamma spans four orders of magnitude; a linear x axis would squash the
# smaller values together at zero.
plt.semilogx(param_range, train_loss_mean, 'r-o', label='train_loss')
plt.semilogx(param_range, test_loss_mean, 'g-o', label='test_loss')
plt.xlabel('gamma')
plt.ylabel('Loss (1 - accuracy)')
plt.legend()
plt.show()
标签:验证 div ace alt logs ati svm bubuko ring
原文地址:https://www.cnblogs.com/slowlyslowly/p/8856711.html