码迷,mamicode.com
首页 > 其他好文 > 详细

训练网络

时间:2018-09-23 22:24:16      阅读:230      评论:0      收藏:0      [点我收藏+]

标签:NPU   sci   设置   vat   tno   如何   nal   12px   tput   

  解决训练任务,包括两部分内容

第一部分:针对给定的训练样本计算输出。这与query()函数所做的工作没什么区别。

第二部分:将计算所得到的输出与期望的目标值做对比,使用差值来指导网络权重的更新。

其中,第一部分的代码如下所示:

 1     def train(self,input_list,target_list):
 2         # 转换输入输出列表到二维数组
 3         inputs = numpy.array(input_list, ndmin=2).T
 4         targets = numpy.array(target_list,ndmin= 2).T
 5         # 计算到隐藏层的信号
 6         hidden_inputs = numpy.dot(self.wih, inputs)
 7         # 计算隐藏层输出的信号
 8         hidden_outputs = self.activation_function(hidden_inputs)
 9         # 计算到输出层的信号
10         final_inputs = numpy.dot(self.who, hidden_outputs)
11         final_outputs = self.activation_function(final_inputs)

这部分与query()中的区别在于多了一个期望值,因为我们需要期望值来训练网络,所以这部分必不可少。

第二部分:

1.首先需要计算误差,也就是期望值减去输出的实际值,以此可表示为:

output_errors = targets - final_outputs

2.那么,如何根据得到的输出误差来更新隐藏层和输出层,输入层和隐藏层之间的权重呢?首先,输出的误差来源于隐藏层传播的误差,隐藏层各个节点的误差具体分配多少呢?

                    errorshidden = weightsThidden_output * errorsoutput

使用python上式可表示为:

hidden_errors = numpy.dot(self.who.T,output_errors)

 

其次,利用公式更新权重:

      ΔWj,k = α * Ek * sigmod(Ok) * (1 - sigmod(OK)) * OjT

使用python可表示为:

1 #隐藏层和输出层权重更新
2         self.who += self.lr * numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),
3                                         numpy.transpose(hidden_outputs))
4         #输入层和隐藏层权重更新
5         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
6                                         numpy.transpose(inputs))

因此,完整的神经网络代码表示为:

 1 import numpy
 2 import scipy.special
 3 
 4 # 神经网络类定义
 5 class NeuralNetwork():
 6     # 初始化神经网络
 7     def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
 8         # 设置输入层节点,隐藏层节点和输出层节点的数量
 9         self.inodes = inputnodes
10         self.hnodes = hiddennodes
11         self.onodes = outputnodes
12         # 学习率设置
13         self.lr = learningrate
14         # 权重矩阵设置 正态分布
15         self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
16         self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
17         # 激活函数设置,sigmod()函数
18         self.activation_function = lambda x: scipy.special.expit(x)
19         pass
20 
21     # 训练神经网络
22     def train(self,input_list,target_list):
23         # 转换输入输出列表到二维数组
24         inputs = numpy.array(input_list, ndmin=2).T
25         targets = numpy.array(target_list,ndmin= 2).T
26         # 计算到隐藏层的信号
27         hidden_inputs = numpy.dot(self.wih, inputs)
28         # 计算隐藏层输出的信号
29         hidden_outputs = self.activation_function(hidden_inputs)
30         # 计算到输出层的信号
31         final_inputs = numpy.dot(self.who, hidden_outputs)
32         final_outputs = self.activation_function(final_inputs)
33 
34         output_errors = targets - final_outputs
35         hidden_errors = numpy.dot(self.who.T,output_errors)
36 
37         #隐藏层和输出层权重更新
38         self.who += self.lr * numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),
39                                         numpy.transpose(hidden_outputs))
40         #输入层和隐藏层权重更新
41         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
42                                         numpy.transpose(inputs))
43         pass
44     # 查询神经网络
45     def query(self, input_list):
46         # 转换输入列表到二维数组
47         inputs = numpy.array(input_list, ndmin=2).T
48         # 计算到隐藏层的信号
49         hidden_inputs = numpy.dot(self.wih, inputs)
50         # 计算隐藏层输出的信号
51         hidden_outputs = self.activation_function(hidden_inputs)
52         # 计算到输出层的信号
53         final_inputs = numpy.dot(self.who, hidden_outputs)
54         final_outputs = self.activation_function(final_inputs)
55 
56         return final_outputs

 

训练网络

标签:NPU   sci   设置   vat   tno   如何   nal   12px   tput   

原文地址:https://www.cnblogs.com/carlber/p/9693600.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!