码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow的hello world

时间:2019-03-01 12:19:24      阅读:134      评论:0      收藏:0      [点我收藏+]

标签:hello   oss   end   滑动   learn   结构   col   mod   main   

import tensorflow as tf;
from tensorflow.examples.tutorials.mnist import input_data


# Network architecture: widths of the input, hidden and output layers.
input_nodes = 784
layer1_nodes = 500
output_nodes = 10

# Hyperparameters.
# Learning-rate schedule: base rate, decay factor and decay interval (in steps).
learning_rate_base = 0.8
learning_decay = 0.99
decay_step = 100

# Moving-average decay, L2 regularisation strength, number of training
# iterations and mini-batch size.
moving_average__decay = 0.99
regularizer_rate = 0.01
train_step = 30000
batch_size = 100


def inference(tensor1, weight1, bias1, weight2, bias2, average_class=None):
    """Compute the forward pass of the two-layer network.

    The hidden layer is ``relu(tensor1 @ weight1 + bias1)`` and the result is
    ``hidden @ weight2 + bias2``; no softmax is applied here.

    Args:
        tensor1: Input tensor fed to the first layer.
        weight1: Weights of the first (hidden) layer.
        bias1: Bias of the first (hidden) layer.
        weight2: Weights of the second (output) layer.
        bias2: Bias of the second (output) layer.
        average_class: Optional object exposing ``average(variable)``. When
            given, every weight and bias is replaced by
            ``average_class.average(variable)`` before it is used.

    Returns:
        The output tensor of the second layer.
    """
    # Identity check: ``None`` is a singleton, so ``is`` is the correct test
    # and cannot be affected by an overloaded ``__eq__``.
    if average_class is None:
        hidden = tf.nn.relu(tf.matmul(tensor1, weight1) + bias1)
        return tf.matmul(hidden, weight2) + bias2

    # Same computation, but reading the averaged value of each variable.
    hidden = tf.nn.relu(
        tf.matmul(tensor1, average_class.average(weight1))
        + average_class.average(bias1)
    )
    return (
        tf.matmul(hidden, average_class.average(weight2))
        + average_class.average(bias2)
    )

def get_weight(shape):
    """Create a weight variable and register its L2 penalty.

    The variable is initialised from a truncated normal distribution with
    standard deviation 0.1. Its L2 regularisation term, scaled by the
    module-level ``regularizer_rate``, is added to the ``'losses'`` graph
    collection.

    Args:
        shape: Shape of the weight variable to create.

    Returns:
        The newly created ``tf.Variable``.
    """
    # The dtype must be passed by keyword: the second positional parameter of
    # tf.Variable is ``trainable``, not ``dtype``.
    weight = tf.Variable(
        tf.truncated_normal(shape=shape, stddev=0.1), dtype=tf.float32
    )
    tf.add_to_collection(
        'losses', tf.contrib.layers.l2_regularizer(regularizer_rate)(weight)
    )
    return weight

def get_bias(shape):
    """Return a new ``tf.Variable`` initialised with ``tf.zeros(shape)``."""
    initial_value = tf.zeros(shape)
    return tf.Variable(initial_value)

def train(mnist):
    """Build the network graph, train it on ``mnist`` and print accuracies.

    Runs ``train_step`` iterations of mini-batch gradient descent. Every 1000
    iterations the accuracy of the moving-average model on the validation
    split is printed; after the loop the accuracy on the test split is printed.

    Args:
        mnist: Dataset object exposing ``train.next_batch(n)`` and the
            ``validation`` / ``test`` splits, each with ``images`` and
            ``labels``. Labels are fed as rows of width ``output_nodes`` and
            converted to class indices with ``argmax`` over axis 1.
    """
    # Inputs and targets.
    train_x = tf.placeholder(tf.float32, shape=[None, input_nodes], name='train_x')
    train_y = tf.placeholder(tf.float32, shape=[None, output_nodes], name='train_y')

    weight1 = get_weight([input_nodes, layer1_nodes])
    bias1 = get_bias([layer1_nodes])

    weight2 = get_weight([layer1_nodes, output_nodes])
    bias2 = get_bias([output_nodes])

    # Learning rate: decays exponentially with the (non-trainable) step counter.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate_base, global_step, decay_step, learning_decay,
        staircase=True,
    )

    # Loss and optimiser.
    results = inference(train_x, weight1, bias1, weight2, bias2, None)
    # The network output is the logits; the labels are the class indices
    # obtained from the label rows. (These two arguments were previously
    # swapped, and the argmax had no axis.)
    ce = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=results, labels=tf.argmax(train_y, 1)
        )
    )
    # Total loss = cross entropy + the L2 terms registered by get_weight().
    loss = ce + tf.add_n(tf.get_collection('losses'))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step
    )

    # Moving averages of all trainable variables.
    ema = tf.train.ExponentialMovingAverage(moving_average__decay, global_step)
    maintain_average_op = ema.apply(tf.trainable_variables())
    # One op that triggers both the gradient step and the average update.
    with tf.control_dependencies([optimizer, maintain_average_op]):
        train_op = tf.no_op(name='train')

    # Accuracy is measured on the averaged parameters.
    average_y = inference(train_x, weight1, bias1, weight2, bias2, ema)
    correction_prediction = tf.equal(
        tf.argmax(average_y, 1), tf.argmax(train_y, 1)
    )
    accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        validate_feed = {
            train_x: mnist.validation.images,
            train_y: mnist.validation.labels,
        }
        test_feed = {train_x: mnist.test.images, train_y: mnist.test.labels}

        # Training loop.
        for i in range(train_step):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print('After %d training steps,using aaverage model is %g ' % (i, validate_acc))

            xt, yt = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={train_x: xt, train_y: yt})

        # Final evaluation on the test split, still inside the session.
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print('accuracy is %g' % (test_acc))
def main():
    """Load the MNIST data set from ``./MNIST_data`` and run ``train`` on it."""
    # Plain ASCII quotes: the typographic quotes in the original are not valid
    # Python string delimiters.
    mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
    train(mnist)

# Run training only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

tensorflow的hello world

标签:hello   oss   end   滑动   learn   结构   col   mod   main   

原文地址:https://www.cnblogs.com/z-bear/p/10455547.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!