标签:
# coding:utf8
"""Softmax (multinomial logistic) regression trained with TensorFlow 1.x.

Training and evaluation build TF graphs (placeholder/Session API); inference
via ``h``/``predict`` is plain numpy and works once ``theta`` is set or loaded.
"""
import os

import numpy as np

try:
    import cPickle  # Python 2
except ImportError:  # Python 3: same API under the name ``pickle``
    import pickle as cPickle

try:
    # NOTE(review): the graph code below uses the TF 1.x API
    # (tf.placeholder / tf.Session); under TF 2 it would presumably need
    # tf.compat.v1 — confirm the targeted TF version.
    import tensorflow as tf
except ImportError:  # numpy-only inference (load_theta/h/predict) still works
    tf = None


def _require_tf():
    """Raise a clear error when a TensorFlow-backed method is used without TF."""
    if tf is None:
        raise ImportError("TensorFlow is required for process_train/process_test")


def _load_pickle(path):
    """Load a pickle file, accepting Python 2 pickles when run under Python 3.

    Security note: unpickling can execute arbitrary code; only load files
    produced by this script or otherwise trusted.
    """
    with open(path, "rb") as fh:
        try:
            # latin1 lets Python 3 read numpy arrays pickled by Python 2.
            return cPickle.load(fh, encoding="latin1")
        except TypeError:  # Python 2's cPickle.load takes no keyword args
            fh.seek(0)
            return cPickle.load(fh)


class SoftMax:
    """Softmax classifier with weight matrix ``theta`` (features x classes)."""

    def __init__(self, MAXT=30, step=0.0025):
        """
        Args:
            MAXT: number of training epochs.
            step: gradient-descent learning rate.
        """
        self.MAXT = MAXT
        self.step = step

    def load_theta(self, datapath="data/softmax.pkl"):
        """Load a pickled weight matrix into ``self.theta`` and return it."""
        self.theta = _load_pickle(datapath)
        return self.theta

    def process_train(self, data, label, typenum=10, batch_size=500,
                      datapath="data/softmax.pkl"):
        """Fit ``theta`` by mini-batch gradient descent on the cross-entropy.

        Args:
            data: 2-D array, one sample per row.
            label: 1-D array of integer class ids, or an already one-hot
                2-D array with ``typenum`` columns.
            typenum: number of classes.
            batch_size: rows per gradient step; the last batch may be smaller.
            datapath: the learned theta is pickled here if no file exists yet.

        Returns:
            The learned weight matrix, shape ``(data.shape[1], typenum)``.

        Raises:
            ValueError: if ``data`` has no rows or ``batch_size`` < 1.
            ImportError: if TensorFlow is not installed.
        """
        _require_tf()
        n_samples, valuenum = data.shape
        if n_samples == 0:
            raise ValueError("process_train needs at least one sample")
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if len(label.shape) == 1:
            label = self.reshape_data(label, typenum)
        # Ceiling division so a trailing partial batch is still trained on.
        batches = (n_samples + batch_size - 1) // batch_size

        # A private graph keeps repeated calls from piling ops onto the default one.
        with tf.Graph().as_default():
            x = tf.placeholder("float", [None, valuenum])
            theta = tf.Variable(tf.zeros([valuenum, typenum]))
            y = tf.nn.softmax(tf.matmul(x, theta))
            y_ = tf.placeholder("float", [None, typenum])
            # Cross-entropy; clip so log(0) cannot turn the loss into NaN/inf.
            cross_entropy = -tf.reduce_sum(
                y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
            train_step = tf.train.GradientDescentOptimizer(
                self.step).minimize(cross_entropy)
            init = tf.initialize_all_variables()
            with tf.Session() as sess:
                sess.run(init)
                for epoch in range(self.MAXT):
                    cost_ = []
                    for index in range(batches):
                        lo = index * batch_size
                        hi = lo + batch_size
                        c_, _ = sess.run(
                            [cross_entropy, train_step],
                            feed_dict={x: data[lo:hi], y_: label[lo:hi]})
                        cost_.append(c_)
                    if epoch % 5 == 0:
                        print("epoch %i, %i minibatches, average cost is %f"
                              % (epoch, batches, np.mean(cost_)))
                self.theta = sess.run(theta)

        # Only the first trained model is persisted; existing files are kept.
        if not os.path.exists(datapath):
            dirname = os.path.dirname(datapath)
            if dirname and not os.path.isdir(dirname):
                os.makedirs(dirname)
            with open(datapath, "wb") as f:
                cPickle.dump(self.theta, f)
        return self.theta

    def process_test(self, data, label, typenum=10):
        """Print and return the accuracy of ``self.theta`` on ``data``.

        Args:
            data: 2-D array, one sample per row.
            label: 1-D integer class ids or a one-hot 2-D array.
            typenum: number of classes.

        Raises:
            ImportError: if TensorFlow is not installed.
        """
        _require_tf()
        valuenum = data.shape[1]
        if len(label.shape) == 1:
            label = self.reshape_data(label, typenum)
        with tf.Graph().as_default():
            x = tf.placeholder("float", [None, valuenum])
            # float32 to match the "float" placeholder, whatever dtype was loaded.
            theta = tf.constant(np.asarray(self.theta, dtype=np.float32))
            y = tf.nn.softmax(tf.matmul(x, theta))
            y_ = tf.placeholder("float", [None, typenum])
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
            with tf.Session() as sess:
                acc = sess.run(accuracy, feed_dict={x: data, y_: label})
        print("Accuracy: %s" % acc)
        return acc

    def h(self, x):
        """Return softmax class probabilities for each row of ``x``.

        A 1-D ``x`` is treated as a single sample. The result has shape
        ``(n_samples, n_classes)`` and each row sums to 1.
        """
        z = np.dot(np.atleast_2d(np.asarray(x)), self.theta)
        # Subtracting the row max leaves the softmax unchanged but avoids exp overflow.
        z = z - np.max(z, axis=1, keepdims=True)
        m = np.exp(z)
        # keepdims so each row is divided by its own sum (not broadcast across columns).
        return m / np.sum(m, axis=1, keepdims=True)

    def predict(self, x):
        """Return the most probable class id for each row of ``x``."""
        return np.argmax(self.h(x), axis=1)

    def reshape_data(self, label, typenum):
        """One-hot encode integer class ids.

        Returns:
            ``np.matrix`` of shape ``(len(label), typenum)`` with a single 1.0 per row.
        """
        ids = np.asarray(label).astype(int).ravel()
        return np.asmatrix(np.eye(typenum)[ids])


if __name__ == '__main__':
    training_data, validation_data, test_data = _load_pickle('mnist.pkl')
    # Flatten every image into a 784-length row vector.
    data = np.asarray(training_data[0]).reshape(-1, 784)
    vdata = np.asarray(validation_data[0]).reshape(-1, 784)

    softmax = SoftMax()
    softmax.process_train(data, training_data[1])
    softmax.process_test(vdata, validation_data[1])  # Accuracy: 0.9269
    softmax.process_test(data, training_data[1])  # Accuracy: 0.92718
标签:
原文地址:http://www.cnblogs.com/qw12/p/5962430.html