码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow学习3---mnist

时间:2018-04-20 11:42:21      阅读:295      评论:0      收藏:0      [点我收藏+]

标签:dice   esc   function   ons   step   run   bat   取数据   hot   

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

 4 ‘‘‘数据下载‘‘‘
 5 mnist=input_data.read_data_sets(Mnist_data,one_hot=True)
 6 #one_hot标签
 7       
 8 ‘‘‘生成层 函数‘‘‘
 9 def add_layer(input,in_size,out_size,n_layer=layer,activation_function=None):
10     layer_name=layer %s % n_layer
11     with tf.name_scope(weights):
12         Weights=tf.Variable(tf.random_normal([in_size,out_size]),name=w)
13         tf.summary.histogram(layer_name+/wights,Weights)
14         #tf.summary.histogram:output summary with histogram直方图
15     with tf.name_scope(biases):
16         biases=tf.Variable(tf.zeros([1,out_size])+0.1)
17         tf.summary.histogram(layer_name+/biases,biases)
18         #tf.summary.histogram:
19     with tf.name_scope(Wx_plus_b):
20         Wx_plus_b=tf.matmul(input,Weights)+biases
21     if activation_function==None:
22         outputs=Wx_plus_b
23     else:
24         outputs=activation_function(Wx_plus_b)
25     tf.summary.histogram(layer_name+/output,outputs)
26     return outputs
# Accuracy.
def compute_accuracy(v_xs, v_ys):
    """Return the classification accuracy of the global ``prediction`` tensor.

    Runs ``prediction`` in the global ``sess`` with ``v_xs`` fed to ``xs``.
    It then compares, row by row, the index of the largest predicted value
    with the index of the largest value in the one-hot label row of ``v_ys``.

    The comparison is done in NumPy on the fetched predictions. Building new
    tf.equal / tf.argmax / tf.cast / tf.reduce_mean ops here, as the original
    did, adds fresh nodes to the default graph on every call, so the graph
    grew each time accuracy was evaluated during training.

    As a side effect, the per-example correctness flags are printed, first as
    booleans and then cast to float32.

    Args:
        v_xs: input examples fed to the ``xs`` placeholder.
        v_ys: one-hot labels, one row per example in ``v_xs``.

    Returns:
        Fraction of correctly classified examples (a float32 scalar).
    """
    y_pre = sess.run(prediction, feed_dict={xs: v_xs})
    # argmax along axis 1 returns the index of the largest value in each row.
    # equal() is True where the predicted index matches the label index,
    # e.g. [ True False  True ...,  True  True  True].
    correct_prediction = np.equal(np.argmax(y_pre, 1), np.argmax(v_ys, 1))
    # Casting turns True into 1.0 and False into 0.0, so the mean of the cast
    # flags is the accuracy. (A cast to int truncates, e.g. [1.8, 2.2] -> [1, 2].)
    correct_as_float = correct_prediction.astype(np.float32)
    print(correct_prediction)
    print(correct_as_float)
    return np.mean(correct_as_float)
# Placeholders: 784 input features per example and 10-way one-hot labels.
xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# Model: a single layer with a softmax activation, used for classification.
prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)

# Loss: cross-entropy between the predicted probabilities and the true labels
# ys. Take -sum(ys * log(prediction)) per example (sum over axis 1), then
# average over the batch.
#
# tf.reduce_mean examples, for x = [[1., 1.], [2., 2.]]:
#   tf.reduce_mean(x)    ==> 1.5
#   tf.reduce_mean(x, 0) ==> [1.5, 1.5]
#   tf.reduce_mean(x, 1) ==> [1., 2.]
#
# tf.reduce_sum examples, for x = [[1, 1, 1], [1, 1, 1]]:
#   tf.reduce_sum(x)                    ==> 6
#   tf.reduce_sum(x, 0)                 ==> [2, 2, 2]
#   tf.reduce_sum(x, 1)                 ==> [3, 3]
#   tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
#   tf.reduce_sum(x, [0, 1])            ==> 6
#
# The probabilities are clipped away from 0 before the log. Softmax outputs
# can underflow to exactly 0, and log(0) = -inf would make the loss and its
# gradients NaN.
cross_entropy = tf.reduce_mean(
    -tf.reduce_sum(ys * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)), axis=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# Session: initialise the variables, then run 1000 mini-batch SGD steps.
# NOTE(review): the histogram summaries created in add_layer are never merged
# (tf.summary.merge_all) or written to a tf.summary.FileWriter in this script.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        # Fetch the next mini-batch of 100 training examples.
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
        if i % 50 == 0:
            # Report test-set accuracy every 50 steps.
            print(compute_accuracy(mnist.test.images, mnist.test.labels))

 

tensorflow学习3---mnist

标签:dice   esc   function   ons   step   run   bat   取数据   hot   

原文地址:https://www.cnblogs.com/ChenKe-cheng/p/8889229.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!