CIFAR10 Custom Network: Hands-On

Posted: 2019-05-25 18:14:48

CIFAR10
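The script below depends on two small pieces of arithmetic. `preprocess()` maps each pixel value p from [0, 255] to 2p/255 - 1, so 0 becomes -1 and 255 becomes 1. With `batch_size = 128`, the 50,000 training and 10,000 test images give ceil(50000/128) = 391 and ceil(10000/128) = 79 batches, which are the step counts shown in the training log (`.batch()` keeps the final partial batch by default). The sketch below checks both in plain Python; `to_unit_range` and `batches_per_epoch` are hypothetical helper names, not part of TensorFlow.

```python
import math


def to_unit_range(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], the same formula preprocess() uses."""
    return 2 * float(pixel) / 255.0 - 1


def batches_per_epoch(num_examples, batch_size):
    """Batches per epoch when the last, partial batch is kept (drop_remainder=False)."""
    return math.ceil(num_examples / batch_size)


print(to_unit_range(0), to_unit_range(255))  # -1.0 1.0
print(batches_per_epoch(50000, 128), batches_per_epoch(10000, 128))  # 391 79
```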

MyDenseLayer
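`MyDense` in the script below computes only `inputs @ self.kernel`; its bias line is commented out, so it behaves roughly like `layers.Dense(outp_dim, use_bias=False)`. The dependency-free sketch below illustrates that forward pass on plain Python lists standing in for tensors, followed by the ReLU that `MyNetwork` applies between layers. It also counts the weights of the five stacked layers: without a bias, each layer holds exactly `inp_dim * outp_dim` values. The names `matmul`, `relu` and `dims` are hypothetical.

```python
def matmul(x, w):
    """x: [b, in] as a list of rows, w: [in, out]; returns x @ w as [b, out]."""
    return [[sum(row[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
            for row in x]


def relu(x):
    """Element-wise max(0, v), like tf.nn.relu."""
    return [[max(0.0, v) for v in row] for row in x]


# A bias-free dense layer: the output is just inputs @ kernel
x = [[3.0, 1.0]]                  # batch of 1 example, 2 features
kernel = [[1.0, -0.5], [-1.0, 0.5]]  # [2, 2] weight matrix
print(matmul(x, kernel))          # [[2.0, -1.0]]
print(relu(matmul(x, kernel)))    # [[2.0, 0.0]]

# Layer widths in MyNetwork: 32*32*3 -> 256 -> 128 -> 64 -> 32 -> 10
dims = [32 * 32 * 3, 256, 128, 64, 32, 10]
num_params = sum(d_in * d_out for d_in, d_out in zip(dims[:-1], dims[1:]))
print(num_params)  # 829760
```

Almost 95% of those weights (786,432) sit in the first layer, which maps the flattened 3,072-value image down to 256 features.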

import os

# Hide TensorFlow's C++ INFO and WARNING logs; this only takes effect if set before TensorFlow is imported
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers
from tensorflow import keras


def preprocess(x, y):
    # [0, 255] --> [-1,1]
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.cast(y, dtype=tf.int32)

    return x, y


batch_size = 128
# x: [50k, 32, 32, 3] uint8 images, y: [50k, 1] integer labels
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
y = tf.squeeze(y)  # [50k, 1] --> [50k]
y_val = tf.squeeze(y_val)  # [10k, 1] --> [10k]
y = tf.one_hot(y, depth=10)  # [50k, 10]
y_val = tf.one_hot(y_val, depth=10)  # [10k, 10]
print('datasets:', x.shape, y.shape, x_val.shape, y_val.shape, x.min(),
      x.max())

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batch_size)

sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)


class MyDense(layers.Layer):
    # to replace standard layers.Dense()
    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()

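        # add_weight() supersedes the deprecated add_variable()
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        # the bias term is commented out in this example, so the layer computes inputs @ kernel only
        # self.bias = self.add_weight(name='b', shape=[outp_dim])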

    def call(self, inputs, training=None):
        x = inputs @ self.kernel
        return x


class MyNetwork(keras.Model):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """inputs: [b, 32, 32, 3]"""
        x = tf.reshape(inputs, [-1, 32 * 32 * 3])
        # [b, 32*32*3] --> [b, 256]
        x = self.fc1(x)
        x = tf.nn.relu(x)
        # [b, 256] --> [b,128]
        x = self.fc2(x)
        x = tf.nn.relu(x)
        # [b, 128] --> [b,64]
        x = self.fc3(x)
        x = tf.nn.relu(x)
        # [b, 64] --> [b,32]
        x = self.fc4(x)
        x = tf.nn.relu(x)
        # [b, 32] --> [b,10]
        x = self.fc5(x)

        return x


network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)

network.evaluate(test_db)
os.makedirs('ckpt', exist_ok=True)
network.save_weights('ckpt/weights.ckpt')
del network
print('saved to ckpt/weights.ckpt')

network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# this extra training is discarded: load_weights() below replaces the trained weights with the saved ones
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)
network.load_weights('ckpt/weights.ckpt')
print('loaded weights from file.')

network.evaluate(test_db)
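Output (recorded from the original run, in which the second compile() call misspelled metrics as metircs, so the second training run and the final evaluation report loss only):

datasets: (50000, 32, 32, 3) (50000, 10) (10000, 32, 32, 3) (10000, 10) 0 255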
batch: (128, 32, 32, 3) (128, 10)
Epoch 1/5
391/391 [==============================] - 7s 19ms/step - loss: 1.7276 - accuracy: 0.3358 - val_loss: 1.5801 - val_accuracy: 0.4427
Epoch 2/5
391/391 [==============================] - 7s 18ms/step - loss: 1.5045 - accuracy: 0.4606 - val_loss: 1.4808 - val_accuracy: 0.4812
Epoch 3/5
391/391 [==============================] - 6s 17ms/step - loss: 1.3919 - accuracy: 0.5019 - val_loss: 1.4596 - val_accuracy: 0.4921
Epoch 4/5
391/391 [==============================] - 7s 18ms/step - loss: 1.3039 - accuracy: 0.5364 - val_loss: 1.4651 - val_accuracy: 0.4950
Epoch 5/5
391/391 [==============================] - 6s 16ms/step - loss: 1.2270 - accuracy: 0.5622 - val_loss: 1.4483 - val_accuracy: 0.5030
79/79 [==============================] - 1s 11ms/step - loss: 1.4483 - accuracy: 0.5030
saved to ckpt/weights.ckpt
Epoch 1/5
391/391 [==============================] - 7s 19ms/step - loss: 1.7216 - val_loss: 1.5773
Epoch 2/5
391/391 [==============================] - 10s 26ms/step - loss: 1.5010 - val_loss: 1.5111
Epoch 3/5
391/391 [==============================] - 8s 21ms/step - loss: 1.3868 - val_loss: 1.4657
Epoch 4/5
391/391 [==============================] - 8s 20ms/step - loss: 1.3021 - val_loss: 1.4586
Epoch 5/5
391/391 [==============================] - 7s 17ms/step - loss: 1.2276 - val_loss: 1.4583
loaded weights from file.
79/79 [==============================] - 1s 12ms/step - loss: 1.4483

1.4482733222502697
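The loss passed to `compile()` is `CategoricalCrossentropy(from_logits=True)`: the last layer outputs raw logits, and the loss applies softmax internally before taking the cross-entropy against the one-hot labels. With one-hot labels, Keras should resolve the `'accuracy'` string to categorical accuracy, the fraction of examples whose largest logit matches the labeled class. The pure-Python sketch below mirrors both computations for a small batch; the function names are hypothetical, and this is an illustration rather than TensorFlow's implementation. A network that outputs equal logits for all 10 classes scores ln 10 ≈ 2.3026, a useful reference point for the loss values in the log above.

```python
import math


def softmax_cross_entropy_from_logits(logits, one_hot):
    """Mean cross-entropy over a batch; logits and one_hot are [b, classes] lists."""
    losses = []
    for z, y in zip(logits, one_hot):
        m = max(z)  # subtract the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        # log softmax_j = z_j - logsumexp(z); loss = -sum_j y_j * log softmax_j
        losses.append(-sum(yj * (zj - log_sum_exp) for zj, yj in zip(z, y)))
    return sum(losses) / len(losses)


def argmax(values):
    """Index of the first largest entry."""
    return max(range(len(values)), key=values.__getitem__)


def categorical_accuracy(logits, one_hot):
    """Fraction of rows whose largest logit sits at the labeled class."""
    hits = [argmax(z) == argmax(y) for z, y in zip(logits, one_hot)]
    return sum(hits) / len(hits)


uniform = [[0.0] * 10]
label = [[1.0] + [0.0] * 9]
print(round(softmax_cross_entropy_from_logits(uniform, label), 4))  # 2.3026

logits = [[2.0, 1.0, 0.0], [0.0, 3.0, 1.0]]
labels = [[1, 0, 0], [0, 0, 1]]
print(categorical_accuracy(logits, labels))  # 0.5
```

Because softmax preserves the ordering of the logits, taking the argmax of the raw logits gives the same prediction as taking it after softmax, so accuracy can be computed without normalizing.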


Original article: https://www.cnblogs.com/nickchen121/p/10923333.html
