码迷,mamicode.com
首页 > 其他好文 > 详细

15、cifar10

时间:2019-12-23 22:22:08      阅读:106      评论:0      收藏:0      [点我收藏+]

标签:level   mode   environ   style   variable   log   amp   ini   import   

  1 import  tensorflow as tf
  2 from    tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
  3 from     tensorflow import keras
  4 import  os
  5 
  6 os.environ[TF_CPP_MIN_LOG_LEVEL] = 2
  7 
  8 
  9 def preprocess(x, y):
 10     # [0~255] => [-1~1]
 11     x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.
 12     y = tf.cast(y, dtype=tf.int32)
 13     return x,y
 14 
 15 
 16 batchsz = 128
 17 # [50k, 32, 32, 3], [10k, 1]
 18 (x, y), (x_val, y_val) = datasets.cifar10.load_data()
 19 y = tf.squeeze(y)
 20 y_val = tf.squeeze(y_val)
 21 y = tf.one_hot(y, depth=10) # [50k, 10]
 22 y_val = tf.one_hot(y_val, depth=10) # [10k, 10]
 23 print(datasets:, x.shape, y.shape, x_val.shape, y_val.shape, x.min(), x.max())
 24 
 25 
 26 train_db = tf.data.Dataset.from_tensor_slices((x,y))
 27 train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz)
 28 test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
 29 test_db = test_db.map(preprocess).batch(batchsz)
 30 
 31 
 32 sample = next(iter(train_db))
 33 print(batch:, sample[0].shape, sample[1].shape)
 34 
 35 
 36 class MyDense(layers.Layer):
 37     # to replace standard layers.Dense()
 38     def __init__(self, inp_dim, outp_dim):
 39         super(MyDense, self).__init__()
 40 
 41         self.kernel = self.add_variable(w, [inp_dim, outp_dim])
 42         # self.bias = self.add_variable(‘b‘, [outp_dim])
 43 
 44     def call(self, inputs, training=None):
 45 
 46         x = inputs @ self.kernel
 47         return x
 48 
 49 class MyNetwork(keras.Model):
 50 
 51     def __init__(self):
 52         super(MyNetwork, self).__init__()
 53 
 54         self.fc1 = MyDense(32*32*3, 256)
 55         self.fc2 = MyDense(256, 128)
 56         self.fc3 = MyDense(128, 64)
 57         self.fc4 = MyDense(64, 32)
 58         self.fc5 = MyDense(32, 10)
 59 
 60 
 61 
 62     def call(self, inputs, training=None):
 63         """
 64 
 65         :param inputs: [b, 32, 32, 3]
 66         :param training:
 67         :return:
 68         """
 69         x = tf.reshape(inputs, [-1, 32*32*3])
 70         # [b, 32*32*3] => [b, 256]
 71         x = self.fc1(x)
 72         x = tf.nn.relu(x)
 73         # [b, 256] => [b, 128]
 74         x = self.fc2(x)
 75         x = tf.nn.relu(x)
 76         # [b, 128] => [b, 64]
 77         x = self.fc3(x)
 78         x = tf.nn.relu(x)
 79         # [b, 64] => [b, 32]
 80         x = self.fc4(x)
 81         x = tf.nn.relu(x)
 82         # [b, 32] => [b, 10]
 83         x = self.fc5(x)
 84 
 85         return x
 86 
 87 
 88 network = MyNetwork()
 89 network.compile(optimizer=optimizers.Adam(lr=1e-3),
 90                 loss=tf.losses.CategoricalCrossentropy(from_logits=True),
 91                 metrics=[accuracy])
 92 network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1)
 93 
 94 network.evaluate(test_db)
 95 network.save_weights(ckpt/weights.ckpt)
 96 del network
 97 print(saved to ckpt/weights.ckpt)
 98 
 99 
100 network = MyNetwork()
101 network.compile(optimizer=optimizers.Adam(lr=1e-3),
102                 loss=tf.losses.CategoricalCrossentropy(from_logits=True),
103                 metrics=[accuracy])
104 network.load_weights(ckpt/weights.ckpt)
105 print(loaded weights from file.)
106 network.evaluate(test_db)

15、cifar10

标签:level   mode   environ   style   variable   log   amp   ini   import   

原文地址:https://www.cnblogs.com/pengzhonglian/p/12088664.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!