码迷,mamicode.com
首页 > 其他好文 > 详细

自定义网络搭建

时间:2020-01-31 17:17:17      阅读:89      评论:0      收藏:0      [点我收藏+]

标签:获取   super   code   class   layer   put   管理   color   com   

使用到的API有:keras.Sequential、Layers/Model

1.keras.Sequential

以前的代码已经很多次用到了这个接口,这里直接给出代码:

model = Sequential([
    layers.Dense(256,activation=tf.nn.relu), # [b,784] ==>[b,256]
    layers.Dense(128,activation=tf.nn.relu),
    layers.Dense(64,activation=tf.nn.relu),
    layers.Dense(32,activation=tf.nn.relu),
    layers.Dense(10)
])

model.build(input_shape=[None,28*28])
model.summary()

Sequential还可以通过一些API去管理参数,如:model.trainable_variables、model.call(),前者是用来获取网络中所有的可训练参数,后者则是相当于逐层调model方法

2.Layer/Model

Layer的全路径为keras.layers.Layer,Model的全路径为keras.Model(包含compile,fit,evaluate功能)

class MyDense(keras.layers.Layer):
    def __init__(self,inp_dim,outp_dim):
        super(MyDense, self).__init__()

        self.kernel = self.add_variable(w,[inp_dim,outp_dim])
        self.bias = self.add_variable(b,[outp_dim])

    def call(self,inputs,training=None):
        out = inputs @ self.kernel + self.bias

        return out
    

class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.fc1 = MyDense(28*28,256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)
        
    def call(self,inputs,training=None):
        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        
        return x

 

自定义网络搭建

标签:获取   super   code   class   layer   put   管理   color   com   

原文地址:https://www.cnblogs.com/zdm-code/p/12245906.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!