码迷,mamicode.com
首页 > 其他好文 > 详细

基于Numpy的神经网络+手写数字识别

时间:2019-06-04 13:22:24      阅读:102      评论:0      收藏:0      [点我收藏+]

标签:actual   lock   pow   use   lam   code   scipy   plt   from   

基于Numpy的神经网络+手写数字识别

本文代码来自Tariq Rashid所著《Python神经网络编程》

代码分为三个部分,框架如下所示:

# neural network class definition
class neuralNetwork:
    """Skeleton of the network class: three empty methods to be filled in later."""

    # initialise the neural network
    def __init__(self):
        """Placeholder constructor; takes `self` so the class can be instantiated."""

    # train the neural network
    def train(self):
        """Placeholder for the training step; does nothing yet."""

    # query the neural network
    def query(self):
        """Placeholder for the query (forward pass) step; does nothing yet."""

这是一个坚实的框架,可以在此基础上逐步充实神经网络工作的具体细节。

import numpy as np
import scipy.special
import matplotlib.pyplot as plt

#neural network class definition
class neuralNetwork:
    """Three-layer (input, hidden, output) network with sigmoid activations.

    The network holds two weight matrices, `wih` (input -> hidden) and
    `who` (hidden -> output). `train` nudges both matrices for one example,
    `query` runs the forward pass only.
    """

    #initialise the neural network
    def __init__(self, inputNodes, hiddenNodes, outputNodes, learningrate):
        """Store the layer sizes and learning rate and draw random weights.

        Args:
            inputNodes: number of nodes in the input layer.
            hiddenNodes: number of nodes in the hidden layer.
            outputNodes: number of nodes in the output layer.
            learningrate: factor applied to every weight update in `train`.
        """
        #set number of nodes in each input, hidden, output layer
        self.inodes = inputNodes
        self.hnodes = hiddenNodes
        self.onodes = outputNodes

        #learning rate
        self.lr = learningrate

        #link weight matrices: wih is (hnodes, inodes), who is (onodes, hnodes).
        #Weights are drawn from a normal distribution centred on zero whose
        #standard deviation is 1/sqrt(number of links coming INTO a node):
        #a hidden node receives inodes links, an output node receives hnodes.
        self.wih = np.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

        #activation function is the sigmoid function
        self.activation_function = scipy.special.expit

    def _forward(self, inputs):
        """Run the forward pass on a column-oriented 2d input array.

        Returns:
            A `(hidden_outputs, final_outputs)` tuple of activated signals.
        """
        #signals into, then emerging from, the hidden layer
        hidden_outputs = self.activation_function(np.dot(self.wih, inputs))
        #signals into, then emerging from, the final output layer
        final_outputs = self.activation_function(np.dot(self.who, hidden_outputs))
        return hidden_outputs, final_outputs

    # train the neural network
    def train(self, inputs_list, targets_list):
        """Update both weight matrices in place from one training example.

        Args:
            inputs_list: input values; converted to a 2d array and transposed
                so that a flat sequence becomes a column vector.
            targets_list: desired output values, converted the same way.
        """
        #convert inputs_list, targets_list to 2d array
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T

        hidden_outputs, final_outputs = self._forward(inputs)

        #output layer error is the (target-actual)
        output_errors = targets - final_outputs
        #hidden layer error is the output_errors, split by weights, recombined
        #at hidden nodes; this must use who BEFORE it is updated below
        hidden_errors = np.dot(self.who.T, output_errors)

        #update the weights for the links between the hidden and output layers
        self.who += self.lr * np.dot(
            output_errors * final_outputs * (1.0 - final_outputs),
            np.transpose(hidden_outputs),
        )
        #update the weights for the links between the input and hidden layers
        self.wih += self.lr * np.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs),
            np.transpose(inputs),
        )

    # query the neural network
    def query(self, inputs_list):
        """Return the output-layer activations for the given inputs.

        Args:
            inputs_list: input values; converted to a 2d array and transposed
                so that a flat sequence becomes a column vector.

        Returns:
            2d array of output signals, one row per output node.
        """
        #convert inputs_list to 2d array
        inputs = np.array(inputs_list, ndmin=2).T

        _, final_outputs = self._forward(inputs)
        return final_outputs

使用以上定义的神经网络类:

#number of input, hidden and output nodes
input_nodes = 784
hidden_nodes = 200
output_nodes = 10

#learning rate is 0.1
learning_rate = 0.1

#create instance of neural network
n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

#load the MNIST training data CSV file into a list
#(the with-block closes the file even if reading fails)
with open("mnist_dataset/mnist_train.csv", "r") as training_data_file:
    training_data_list = training_data_file.readlines()

#train the neural network

#epochs is the number of times the training data set is used for training
epochs = 5

for e in range(epochs):
    #go through all records in the training data set
    for record in training_data_list:
        #skip blank lines, which would otherwise fail the int() label parse
        if not record.strip():
            continue
        #split the record by the "," commas
        all_values = record.split(",")
        #scale and shift the inputs from 0..255 into 0.01..1.00
        #(np.asfarray was removed in NumPy 2.0; asarray with a float dtype
        #performs the same conversion)
        inputs = (np.asarray(all_values[1:], dtype=np.float64) / 255.0 * 0.99) + 0.01
        #create the target output values (all 0.01, except the desired label which is 0.99)
        targets = np.zeros(output_nodes) + 0.01
        #all_values[0] is the target label for this record
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)

#load the MNIST test data CSV file into a list
#(the with-block closes the file even if reading fails)
with open("mnist_dataset/mnist_test.csv", 'r') as test_data_file:
    test_data_list = test_data_file.readlines()


#test the neural network
#scorecard for how well the network performs, initially empty
scorecard = []

#go through all the records in the test data set
for record in test_data_list:
    #skip blank lines, which would otherwise fail the int() label parse
    if not record.strip():
        continue
    #split the record by the ',' commas
    all_values = record.split(',')
    #correct answer is the first value
    correct_label = int(all_values[0])
    #scale and shift the inputs from 0..255 into 0.01..1.00
    #(np.asfarray was removed in NumPy 2.0; asarray with a float dtype
    #performs the same conversion)
    inputs = (np.asarray(all_values[1:], dtype=np.float64) / 255.0 * 0.99) + 0.01
    #query the network
    outputs = n.query(inputs)
    #the index of the highest value corresponds to the label
    label = np.argmax(outputs)
    #record 1 when the network's answer matches the correct answer, else 0
    scorecard.append(1 if label == correct_label else 0)

#calculate the performance score, the fraction of correct answers
scorecard_array = np.asarray(scorecard)
print("performance = ", scorecard_array.sum()/scorecard_array.size)

以上训练中所用到的数据集:

训练集

测试集

基于Numpy的神经网络+手写数字识别

标签:actual   lock   pow   use   lam   code   scipy   plt   from   

原文地址:https://www.cnblogs.com/xxxxxxxxx/p/10972614.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!