TensorFlow CNN

Posted: 2019-07-03 16:43:32

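A simple classification task: a small convolutional network trained on the MNIST handwritten digits, written against the TensorFlow 1.x API.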

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

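# Load MNIST (downloaded into data/ if not already present); one_hot=True gives 10-dim label vectors
mnist = input_data.read_data_sets('data/', one_hot=True)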

# to be able to rerun the model without overwriting tf variables
tf.reset_default_graph()

num_classes = 10     # digits 0-9
batch_size = 64
num_train = 10000    # number of training steps (mini-batches), not number of examples

X = tf.placeholder(tf.float32, [None, 28,28,1])
y = tf.placeholder(tf.float32, [None, num_classes])
# CONV2D -> RELU -> MAXPOOL -> CONV2D -> RELU -> MAXPOOL -> FLATTEN -> FULLYCONNECTED -> RELU -> FULLYCONNECTED
W1 = tf.Variable(tf.random_normal(shape=[5,5,1,32], stddev=0.1))
b1 = tf.Variable(tf.constant(0.1, shape=[32]))
W2 = tf.Variable(tf.random_normal(shape=[5,5,32,64], stddev=0.1))
b2 = tf.Variable(tf.constant(0.1, shape=[64]))
# Convolutional layers (5x5 kernels, stride 1, SAME padding; each 2x2 max-pool halves the spatial size)
conv_1 = tf.nn.relu(tf.nn.conv2d(X, filter=W1, strides=[1,1,1,1], padding="SAME") + b1)
pool_1 = tf.nn.max_pool(conv_1, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")
conv_2 = tf.nn.relu(tf.nn.conv2d(pool_1, filter=W2, strides=[1,1,1,1], padding="SAME") + b2)
pool_2 = tf.nn.max_pool(conv_2, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")

# Fully connected layers
W_fc1 = tf.Variable(tf.random_normal(shape=[7*7*64, 1024], stddev=0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
W_fc2 = tf.Variable(tf.random_normal(shape=[1024, num_classes], stddev=0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape=[num_classes]))

# The two stride-2 max-pools reduce 28x28 to 7x7, with 64 channels, so flatten to 7*7*64
pool_2_flat = tf.reshape(pool_2, [-1, 7*7*64])
fc1 = tf.nn.relu(tf.matmul(pool_2_flat, W_fc1) + b_fc1)
fc2 = tf.matmul(fc1, W_fc2) + b_fc2  # logits: output of the final linear layer (softmax is applied inside the loss)

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=fc2))
train_op = tf.train.AdamOptimizer().minimize(loss)
correct_prediction = tf.equal(tf.argmax(fc2, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_train):
        mini_batch = mnist.train.next_batch(batch_size)
        X_temp = mini_batch[0].reshape([batch_size, 28,28,1])
        y_temp = mini_batch[1]
        sess.run(train_op, feed_dict={X:X_temp, y:y_temp})
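        # Every 1000 steps, report loss and accuracy on the current training mini-batch (not a held-out set)
        if step % 1000 == 0: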
            loss_var, accuracy_var = sess.run([loss,accuracy], feed_dict={X:X_temp, y:y_temp})
            print("loss:", loss_var, "accuracy:", accuracy_var)
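
The hard-coded 7*7*64 in the fully connected layer follows from the shapes: with padding="SAME", TensorFlow's convolution and pooling ops produce an output side of ceil(input / stride), so the two stride-2 max-pools take 28 to 14 to 7. The plain-Python sketch below (no TensorFlow needed; the helper name same_out is hypothetical) re-derives that size and counts the trainable weights and biases of each layer in the listing above, assuming the same kernel and layer sizes.

```python
import math


def same_out(size: int, stride: int) -> int:
    """Output side length of a conv/pool op with padding="SAME": ceil(size / stride)."""
    return math.ceil(size / stride)


# Spatial side length through conv_1 -> pool_1 -> conv_2 -> pool_2 (strides 1, 2, 1, 2)
side = 28
for stride in (1, 2, 1, 2):
    side = same_out(side, stride)
flat_dim = side * side * 64  # pool_2 has 64 channels

# Trainable parameters (weights + biases) for each layer, using the listing's shapes
params = {
    "conv_1": 5 * 5 * 1 * 32 + 32,
    "conv_2": 5 * 5 * 32 * 64 + 64,
    "fc1": flat_dim * 1024 + 1024,
    "fc2": 1024 * 10 + 10,
}

print("side after pool_2:", side)
print("flatten size:", flat_dim)
for name, n in params.items():
    print(f"{name}: {n:,}")
print(f"total: {sum(params.values()):,}")
```

Under these assumptions the network has 3,274,634 parameters, and fc1 alone holds about 98% of them, so the first dense layer dominates the model size.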

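The loss and accuracy ops work on the raw logits in fc2 and the one-hot labels. The sketch below restates that math in plain Python for a made-up batch of two examples with three classes (the listing uses 10): the per-example cross-entropy is -sum_i y_i * log(softmax(z)_i), averaged over the batch as tf.reduce_mean does, and accuracy is the fraction of examples whose logit argmax matches the label argmax. It illustrates the formulas only, not TensorFlow's internal implementation; the function names and the batch values are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def softmax_cross_entropy(labels: List[float], logits: List[float]) -> float:
    """Per-example loss: -sum_i labels[i] * log(softmax(logits)[i])."""
    return -sum(y * lp for y, lp in zip(labels, log_softmax(logits)))


def argmax(values: List[float]) -> int:
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical batch: two examples, three classes, one-hot labels
logits_batch = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
labels_batch = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

# Mean loss over the batch, mirroring tf.reduce_mean(...) in the listing
loss = sum(softmax_cross_entropy(y, z) for y, z in zip(labels_batch, logits_batch)) / len(logits_batch)
# Accuracy: compare argmax of logits with argmax of one-hot labels
correct = [argmax(z) == argmax(y) for y, z in zip(labels_batch, logits_batch)]
accuracy = sum(correct) / len(correct)
print(f"loss={loss:.4f} accuracy={accuracy}")
```

This is also why fc2 has no activation: the TensorFlow loss op expects unscaled logits and applies the softmax itself, so feeding it softmax probabilities would apply the softmax twice.
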
Original article: https://blog.51cto.com/5669384/2416716