标签: input, errors, res, target, error, [1], target, python, label
本实验使用 MNIST 数据集完成手写数字识别的训练与测试，识别正确率约为 95%。
完整代码如下:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy
import scipy.special
import matplotlib.pyplot


# In[2]:


class neuralNetwork:
    """Three-layer (input, hidden, output) fully connected network.

    Both layers use the logistic sigmoid activation; training runs one
    example at a time with plain gradient descent (backpropagation).
    The lowercase class name is kept so existing callers keep working.
    """

    def __init__(self, inputNodes, hiddenNodes, outputNodes, learningRate):
        """Create the network with randomly initialised weights.

        Args:
            inputNodes: number of input nodes.
            hiddenNodes: number of hidden nodes.
            outputNodes: number of output nodes.
            learningRate: step size applied to every weight update.
        """
        self.iNodes = inputNodes
        self.hNodes = hiddenNodes
        self.oNodes = outputNodes
        self.lr = learningRate

        # Weight matrices are shaped (to_layer, from_layer) so a layer's
        # input is simply W @ x. Entries are drawn from N(0, n ** -0.5),
        # where n is the size of the receiving layer (as in the original).
        # wih is drawn before who, so seeded runs reproduce the original.
        self.wih = numpy.random.normal(
            0.0, pow(self.hNodes, -0.5), (self.hNodes, self.iNodes))
        self.who = numpy.random.normal(
            0.0, pow(self.oNodes, -0.5), (self.oNodes, self.hNodes))

        # Logistic sigmoid 1 / (1 + exp(-x)).
        self.activation_function = scipy.special.expit

    def _forward(self, inputs):
        """Propagate a column vector through both layers.

        Returns (hidden_outputs, final_outputs), each a column vector.
        """
        hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))
        final_outputs = self.activation_function(
            numpy.dot(self.who, hidden_outputs))
        return hidden_outputs, final_outputs

    def train(self, inputs_list, target_list):
        """Run one backpropagation step on a single example.

        Args:
            inputs_list: iNodes input values (1-D sequence or array).
            target_list: oNodes desired output values.
        """
        inputs = numpy.array(inputs_list, ndmin=2).T
        targets = numpy.array(target_list, ndmin=2).T
        hidden_outputs, final_outputs = self._forward(inputs)

        output_errors = targets - final_outputs
        # Must use the *current* who: compute before who is updated below.
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # The sigmoid's derivative is s * (1 - s).
        self.who += self.lr * numpy.dot(
            output_errors * final_outputs * (1.0 - final_outputs),
            hidden_outputs.T)
        self.wih += self.lr * numpy.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs),
            inputs.T)

    def query(self, inputs_list):
        """Return the network output for one example as an (oNodes, 1) column."""
        inputs = numpy.array(inputs_list, ndmin=2).T
        _, final_outputs = self._forward(inputs)
        return final_outputs


# In[3]:


# 784 inputs = one per pixel of a 28x28 image; 10 outputs = one per label.
inputNodes = 784
outputNodes = 10
hiddenNodes = 100
learningRate = 0.1
nN = neuralNetwork(inputNodes, hiddenNodes, outputNodes, learningRate)
# In[4]:


def _read_records(path):
    """Return the non-blank lines of the text file at *path*.

    Blank lines (e.g. a trailing newline at end of file) are dropped so that
    parsing a record never hits an empty label.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [line for line in f if line.strip()]


def _scale_inputs(pixel_values):
    """Convert raw pixel values (0..255) to floats in 0.01..1.0.

    The 0.01 offset keeps inputs away from zero: a zero input makes the
    corresponding wih update in neuralNetwork.train zero as well.
    numpy.asfarray was removed in NumPy 2.0, hence asarray(dtype=float).
    """
    return numpy.asarray(pixel_values, dtype=float) / 255.0 * 0.99 + 0.01


# Each CSV record: label, followed by 784 pixel values (a 28x28 image).
data_list = _read_records("mnist_train.csv")


# In[5]:


epochs = 1
for _epoch in range(epochs):
    for record in data_list:
        all_values = record.strip().split(',')
        inputs = _scale_inputs(all_values[1:])
        # Target: 0.99 for the true label, 0.01 elsewhere (a sigmoid output
        # can never actually reach 0 or 1).
        targets = numpy.zeros(outputNodes) + 0.01
        targets[int(all_values[0])] = 0.99
        nN.train(inputs, targets)


# In[6]:


test_data_list = _read_records("mnist_test.csv")


# In[7]:


# 1 for every correctly classified test record, 0 otherwise.
scorecard = []
for record in test_data_list:
    all_values = record.strip().split(',')
    correct_label = int(all_values[0])
    outputs = nN.query(_scale_inputs(all_values[1:]))
    label = numpy.argmax(outputs)
    scorecard.append(1 if label == correct_label else 0)


# In[8]:


scorecard_array = numpy.asarray(scorecard)
print("performance = ", scorecard_array.sum() / scorecard_array.size)


# In[9]:


# Classify a hand-made image. scipy.misc.imread no longer exists in current
# SciPy, so the image is read with matplotlib instead.
img_array = matplotlib.pyplot.imread('test.png')
if img_array.dtype.kind == 'f':
    # matplotlib returns PNG data as floats in 0..1; restore the 0..255 scale.
    img_array = img_array * 255.0
if img_array.ndim == 3:
    # Collapse RGB(A) to grayscale with ITU-R 601-2 luma weights (PIL's
    # RGB -> L formula, which the old imread(flatten=True) relied on).
    img_array = img_array[..., :3] @ numpy.array([0.299, 0.587, 0.114])
# The 255 - x inversion presumably turns dark-ink-on-white into the
# training data's encoding. reshape(784) needs exactly 28x28 = 784 pixels.
img_data = _scale_inputs(255.0 - img_array.reshape(784))
op = nN.query(img_data)
print(op)
print(numpy.argmax(op))


# In[10]:


# Display the second training record as a 28x28 image.
all_values = data_list[1].strip().split(',')
image_array = numpy.asarray(all_values[1:], dtype=float).reshape((28, 28))
matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')
In[9] 和 In[10] 两段代码分别用于测试模型对自制手写数字图片（test.png，需为 28×28 像素）的识别效果，以及显示训练集中的一张图像。
运行代码需要 MNIST 数据集（mnist_train.csv 和 mnist_test.csv），下载链接：https://pan.baidu.com/s/120GTdZ8Tivkp1KD9VQ_XeQ
标签: input, errors, res, target, error, [1], target, python, label
原文地址:https://www.cnblogs.com/bai2018/p/10353747.html