码迷,mamicode.com
首页 > 其他好文 > 详细

采用 libsvm 进行 MNIST 训练

时间:2016-08-06 14:18:28      阅读:353      评论:0      收藏:0      [点我收藏+]

标签:

 1 #coding:utf8
 2 import cPickle
 3 import gzip
 4 import numpy as np
 5 from sklearn.svm import libsvm
 6 
 7 
 8 class SVM(object):
 9     def __init__(self, kernel=rbf, degree=3, gamma=auto,
10                  coef0=0.0, tol=1e-3, C=1.0,nu=0., epsilon=0.,shrinking=True, probability=False,
11                   cache_size=200, class_weight=None, max_iter=-1):
12         self.kernel = kernel
13         self.degree = degree
14         self.gamma = gamma
15         self.coef0 = coef0
16         self.tol = tol
17         self.C = C
18         self.nu = nu
19         self.epsilon = epsilon
20         self.shrinking = shrinking
21         self.probability = probability
22         self.cache_size = cache_size
23         self.class_weight = class_weight
24         self.max_iter = max_iter
25 
26     def fit(self, X, y):
27         X= np.array(X, dtype=np.float64, order=C)
28         cls, y = np.unique(y, return_inverse=True)
29         weight = np.ones(cls.shape[0], dtype=np.float64, order=C)
30         self.class_weight_=weight
31         self.classes_ = cls
32         y= np.asarray(y, dtype=np.float64, order=C)
33         sample_weight = np.asarray([])
34         solver_type =0
35         self._gamma = 1.0 / X.shape[1]
36         kernel = self.kernel
37         seed = np.random.randint(np.iinfo(i).max)
38         self.support_, self.support_vectors_, self.n_support_, 39             self.dual_coef_, self.intercept_, self.probA_, 40             self.probB_, self.fit_status_ = libsvm.fit(
41                 X, y,
42                 svm_type=solver_type, sample_weight=sample_weight,
43                 class_weight=self.class_weight_, kernel=kernel, C=self.C,
44                 nu=self.nu, probability=self.probability, degree=self.degree,
45                 shrinking=self.shrinking, tol=self.tol,
46                 cache_size=self.cache_size, coef0=self.coef0,
47                 gamma=self._gamma, epsilon=self.epsilon,
48                 max_iter=self.max_iter, random_seed=seed)
49         self.shape_fit_ = X.shape
50         self._intercept_ = self.intercept_.copy()
51         self._dual_coef_ = self.dual_coef_
52         self.intercept_ *= -1
53         self.dual_coef_ = -self.dual_coef_
54         return self
55 
56     def predict(self, X):
57         X= np.array(X,dtype=np.float64, order=C)
58         svm_type = 0
59         return libsvm.predict(
60             X, self.support_, self.support_vectors_, self.n_support_,
61             self._dual_coef_, self._intercept_,
62             self.probA_, self.probB_, svm_type=svm_type, kernel=self.kernel,
63             degree=self.degree, coef0=self.coef0, gamma=self._gamma,
64             cache_size=self.cache_size)
65 
def load_data(path='../data/mnist.pkl.gz'):
    """Load a gzip-compressed pickle holding the MNIST splits.

    ``path`` defaults to the original hard-coded location; a relative path is
    resolved against the current working directory, not this file.

    Returns the ``(training_data, validation_data, test_data)`` triple stored
    in the pickle.

    Security note: unpickling can execute arbitrary code; only load a
    trusted file.
    """
    try:
        import cPickle as pickle  # Python 2, which the rest of this script targets
        load_kwargs = {}
    except ImportError:
        import pickle  # Python 3
        # Pickles written by Python 2 that contain NumPy arrays need latin1
        # decoding to load under Python 3.
        load_kwargs = {'encoding': 'latin1'}
    # ``with`` closes the file even if unpickling raises.
    with gzip.open(path, 'rb') as f:
        training_data, validation_data, test_data = pickle.load(f, **load_kwargs)
    return (training_data, validation_data, test_data)
71 
def svm_test(n_train=100, n_test=100, kernel='linear'):
    """Train an SVM on a small MNIST subset, print its test accuracy and return it.

    The defaults reproduce the original demo: the first 100 training images,
    scored on the first 100 test images, with a linear kernel.

    Returns the number of correctly classified test images.
    """
    training_data, _validation_data, test_data = load_data()
    # kernel: 'linear', 'poly', 'rbf', 'sigmoid' or 'precomputed'
    clf = SVM(kernel=kernel)
    clf.fit(training_data[0][:n_train], training_data[1][:n_train])
    predictions = [int(a) for a in clf.predict(test_data[0][:n_test])]
    expected = test_data[1][:n_test]
    num_correct = sum(int(a == y) for a, y in zip(predictions, expected))
    # print(...) with a single argument behaves identically on Python 2 and 3.
    print("Baseline classifier using an SVM.")
    # NOTE(review): the original post quotes accuracies of 0.9172 (linear) and
    # 0.9214 ('rbf'); presumably measured on a larger split than 100/100.
    print("%s of %s values correct." % (num_correct, len(expected)))
    return num_correct
80 
def _main():
    """Script entry point: run the small MNIST SVM demo."""
    svm_test()


if __name__ == "__main__":
    _main()

 

采用 libsvm 进行 MNIST 训练

标签:

原文地址:http://www.cnblogs.com/qw12/p/5743865.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!