码迷,mamicode.com
首页 > 其他好文 > 详细

莫烦课程Batch Normalization 批标准化

时间:2017-12-13 20:34:50      阅读:205      评论:0      收藏:0      [点我收藏+]

标签:github   http   取出   数据   target   har   hub   bat   self   

 for i in range(N_HIDDEN):               # build hidden layers and BN layers
            input_size = 1 if i == 0 else 10
            fc = nn.Linear(input_size, 10)
            setattr(self, ‘fc%i‘ % i, fc)       # IMPORTANT set layer to the Module
            self._set_init(fc)                  # parameters initialization
            self.fcs.append(fc)
            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, ‘bn%i‘ % i, bn)   # IMPORTANT set layer to the Module
self.bns.append(bn)

 上面的代码对每个隐层进行批标准化,setattr(self, ‘fc%i‘ % i, fc)作用相当于self.fci=fc

每次生成的结果append到bns的最后面,结果的size 10×10,取出这些数据是非常方便

def forward(self, x):
        pre_activation = [x]
        if self.do_bn: x = self.bn_input(x)     # input batch normalization
        layer_input = [x]
        for i in range(N_HIDDEN):
            x = self.fcs[i](x)
            pre_activation.append(x)
            if self.do_bn: x = self.bns[i](x)   # batch normalization
            x = ACTIVATION(x)
            layer_input.append(x)
        out = self.predict(x)
return out, layer_input, pre_activation

全部的源代码

 

莫烦课程Batch Normalization 批标准化

标签:github   http   取出   数据   target   har   hub   bat   self   

原文地址:http://www.cnblogs.com/lindaxin/p/8034069.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!