码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow Logistic

时间:2019-07-02 00:26:08      阅读:97      评论:0      收藏:0      [点我收藏+]

标签:sha   oss   amp   input   max   loss   des   step   调试   

写代码真的要小心的,小问题调试半天。。。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets(‘data/‘, one_hot=True)

num_classes = 10
input_size = 784
train_iter = 50000
batch_size = 64

X = tf.placeholder(tf.float32, shape=[None, input_size])
y = tf.placeholder(tf.float32, shape=[None, num_classes])

W = tf.Variable(tf.random_normal([input_size, num_classes], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[num_classes]))

y_pred = tf.nn.softmax(tf.matmul(X, W) + b)
loss = tf.reduce_mean(tf.square(y - y_pred))
train = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)

correct_pred = tf.equal(tf.argmax(y_pred,1), tf.argmax(y,1)) #通过最大值的位置是否一支来判断结果是否判断正确
accuracy = tf.reduce_mean(tf.cast(correct_pred, "float"))#求平均值,就是正确率

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(train_iter):
        batch = mnist.train.next_batch(batch_size)
        X_temp = batch[0]
        y_temp = batch[1]
        sess.run(train, feed_dict={X:X_temp, y:y_temp})
        if step % 1000 == 0:
            train_accu = sess.run(accuracy, feed_dict={X:X_temp, y:y_temp})
            print(train_accu)

Tensorflow Logistic

标签:sha   oss   amp   input   max   loss   des   step   调试   

原文地址:https://blog.51cto.com/5669384/2416040

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!