标签:
# coding: utf8
"""Baseline MNIST classifier built directly on scikit-learn's low-level libsvm binding.

``SVM`` is a minimal C-SVC wrapper that calls ``libsvm.fit`` / ``libsvm.predict``
itself instead of going through ``sklearn.svm.SVC``. The script is kept
runnable on both Python 2 and Python 3.
"""
import gzip
import sys

try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle  # Python 3

import numpy as np

try:
    from sklearn.svm import libsvm
except ImportError:
    # Newer scikit-learn releases only ship the binding as the private
    # ``_libsvm`` module.
    from sklearn.svm import _libsvm as libsvm


class SVM(object):
    """Minimal C-support-vector classifier on top of the raw libsvm binding.

    Only C-SVC (libsvm ``svm_type=0``) is supported.

    Parameters
    ----------
    kernel : str
        Kernel name passed straight to libsvm: 'linear', 'poly', 'rbf',
        'sigmoid' or 'precomputed'.
    degree : int
        Polynomial degree, forwarded to libsvm.
    gamma : 'auto', 'scale' or number
        Kernel coefficient. 'auto' -> 1 / n_features;
        'scale' -> 1 / (n_features * X.var()) (1.0 if the variance is 0);
        a number is used as-is.
    coef0, tol, C, nu, epsilon, shrinking, probability, cache_size, max_iter
        Forwarded unchanged to ``libsvm.fit``.
    class_weight : None, dict or 'balanced'
        Per-class multiplier for C. None -> every class gets 1.0; a dict maps
        original labels to weights (labels not in the dict get 1.0);
        'balanced' -> n_samples / (n_classes * count_of_class).
    """

    _SVM_TYPE = 0  # the only svm_type this wrapper passes to libsvm

    def __init__(self, kernel='rbf', degree=3, gamma='auto',
                 coef0=0.0, tol=1e-3, C=1.0, nu=0., epsilon=0., shrinking=True,
                 probability=False, cache_size=200, class_weight=None,
                 max_iter=-1):
        self.kernel = kernel
        self.degree = degree
        self.gamma = gamma
        self.coef0 = coef0
        self.tol = tol
        self.C = C
        self.nu = nu
        self.epsilon = epsilon
        self.shrinking = shrinking
        self.probability = probability
        self.cache_size = cache_size
        self.class_weight = class_weight
        self.max_iter = max_iter

    def _resolve_gamma(self, X):
        """Return the numeric kernel coefficient for the 2-D training matrix ``X``."""
        gamma = self.gamma
        if gamma == 'auto':
            return 1.0 / X.shape[1]
        if gamma == 'scale':
            variance = X.var()
            return 1.0 / (X.shape[1] * variance) if variance != 0 else 1.0
        return float(gamma)

    def _class_weights(self, y_idx):
        """Return a float64 array with one C-multiplier per entry of ``self.classes_``.

        ``y_idx`` holds the class index (into ``self.classes_``) of each sample.
        Raises ValueError for an unsupported ``class_weight`` value.
        """
        n_classes = self.classes_.shape[0]
        cw = self.class_weight
        if cw is None:
            return np.ones(n_classes, dtype=np.float64, order='C')
        if isinstance(cw, dict):
            return np.array([float(cw.get(c, 1.0)) for c in self.classes_],
                            dtype=np.float64)
        if cw == 'balanced':
            counts = np.bincount(y_idx, minlength=n_classes)
            return (float(len(y_idx)) / (n_classes * counts)).astype(np.float64)
        raise ValueError("class_weight must be None, a dict or 'balanced', "
                         "got %r" % (cw,))

    def fit(self, X, y):
        """Train the classifier.

        Parameters
        ----------
        X : array-like
            Training samples; converted to a C-ordered float64 array
            (presumably (n_samples, n_features) -- libsvm needs 2-D input).
        y : array-like
            Label of each sample. Any labels ``np.unique`` can sort are accepted.

        Returns
        -------
        self
        """
        X = np.array(X, dtype=np.float64, order='C')
        # libsvm trains on class indices 0..k-1; keep the original labels so
        # predict() can translate indices back.
        self.classes_, y_idx = np.unique(y, return_inverse=True)
        self.class_weight_ = self._class_weights(y_idx)
        y_encoded = np.asarray(y_idx, dtype=np.float64, order='C')
        sample_weight = np.asarray([], dtype=np.float64)  # empty: no per-sample weights
        self._gamma = self._resolve_gamma(X)
        seed = np.random.randint(np.iinfo('i').max)
        result = libsvm.fit(
            X, y_encoded,
            svm_type=self._SVM_TYPE, sample_weight=sample_weight,
            class_weight=self.class_weight_, kernel=self.kernel, C=self.C,
            nu=self.nu, probability=self.probability, degree=self.degree,
            shrinking=self.shrinking, tol=self.tol,
            cache_size=self.cache_size, coef0=self.coef0,
            gamma=self._gamma, epsilon=self.epsilon,
            max_iter=self.max_iter, random_seed=seed)
        # Keep the first eight fields. NOTE(review): some newer scikit-learn
        # releases appear to append extra fields (e.g. an iteration count);
        # slicing keeps this working either way.
        (self.support_, self.support_vectors_, self.n_support_,
         self.dual_coef_, self.intercept_, self.probA_,
         self.probB_, self.fit_status_) = result[:8]
        self.shape_fit_ = X.shape
        # predict() must receive libsvm's raw coefficients, so keep copies
        # before the public attributes are sign-flipped below.
        self._intercept_ = self.intercept_.copy()
        self._dual_coef_ = self.dual_coef_
        self.intercept_ *= -1
        self.dual_coef_ = -self.dual_coef_
        return self

    def predict(self, X):
        """Predict a label for each sample in ``X``.

        Returns an array of labels taken from ``self.classes_`` (the labels
        seen in ``fit``), not libsvm's internal class indices.
        """
        X = np.array(X, dtype=np.float64, order='C')
        class_idx = libsvm.predict(
            X, self.support_, self.support_vectors_, self.n_support_,
            self._dual_coef_, self._intercept_,
            self.probA_, self.probB_, svm_type=self._SVM_TYPE, kernel=self.kernel,
            degree=self.degree, coef0=self.coef0, gamma=self._gamma,
            cache_size=self.cache_size)
        # libsvm answers with (float) class indices; map them back to labels.
        return self.classes_.take(np.asarray(class_idx, dtype=np.intp))


def load_data(path='../data/mnist.pkl.gz'):
    """Load a gzipped pickle holding (training, validation, test) data.

    NOTE(review): for the usual mnist.pkl.gz each part is presumably an
    (images, labels) pair -- confirm against your copy. Only load trusted
    files: unpickling can execute arbitrary code.
    """
    with gzip.open(path, 'rb') as f:
        if sys.version_info[0] >= 3:
            # The file is presumably written by Python 2's cPickle; latin1
            # lets Python 3 decode its byte strings / numpy buffers.
            data = pickle.load(f, encoding='latin1')
        else:
            data = pickle.load(f)
    training_data, validation_data, test_data = data
    return (training_data, validation_data, test_data)


def svm_test(kernel='linear', n_train=100, n_test=100):
    """Fit ``SVM`` on the first ``n_train`` training samples and print how many
    of the first ``n_test`` test samples it classifies correctly.

    ``kernel`` is one of 'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'.
    Author's recorded accuracies (presumably on the full data set):
    linear 0.9172, rbf 0.9214.
    """
    training_data, _, test_data = load_data()
    clf = SVM(kernel=kernel)
    clf.fit(training_data[0][:n_train], training_data[1][:n_train])
    predictions = [int(a) for a in clf.predict(test_data[0][:n_test])]
    test_labels = test_data[1][:n_test]
    num_correct = sum(int(a == y) for a, y in zip(predictions, test_labels))
    print("Baseline classifier using an SVM.")
    print("%s of %s values correct." % (num_correct, len(test_labels)))


if __name__ == "__main__":
    svm_test()
标签:
原文地址:http://www.cnblogs.com/qw12/p/5743865.html