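After training a model, we usually want to save the result, that is, the model's parameters, so that we can resume training later or load them for testing. TensorFlow provides the Saver class (tf.train.Saver) for this purpose.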
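Conceptually, a checkpoint maps variable names to their current values and is written under a file prefix that carries the training step (Saver.save appends "-<global_step>" to the path). Saver also keeps a small index file named "checkpoint" in the same directory that points at the newest one. The sketch below imitates that layout in plain Python purely for illustration. The helpers save_params, latest_checkpoint and restore_params are hypothetical, and the JSON storage and the plain-text index are simplifying assumptions, not TensorFlow's real on-disk format.

```python
import json
import os
import tempfile


def save_params(params, ckpt_dir, step, prefix="model.ckpt"):
    """Hypothetical stand-in for Saver.save: store params under '<prefix>-<step>'."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"{prefix}-{step}")
    # Assumption: JSON instead of TensorFlow's binary .index/.data files
    with open(path + ".json", "w") as f:
        json.dump(params, f)
    # Record the newest checkpoint, similar in spirit to Saver's 'checkpoint' file
    with open(os.path.join(ckpt_dir, "checkpoint"), "w") as f:
        f.write(path)
    return path


def latest_checkpoint(ckpt_dir):
    """Return the newest checkpoint path, or None if nothing was saved yet."""
    index = os.path.join(ckpt_dir, "checkpoint")
    if not os.path.exists(index):
        return None
    with open(index) as f:
        return f.read().strip()


def restore_params(path):
    """Load the parameter dict written by save_params."""
    with open(path + ".json") as f:
        return json.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as ckpt_dir:
        # Illustrative values only, saved at steps 50 and 100 (checkpoint_steps = 50)
        save_params({"w": [3.9], "b": [2.1]}, ckpt_dir, 50)
        save_params({"w": [4.0], "b": [2.0]}, ckpt_dir, 100)
        path = latest_checkpoint(ckpt_dir)
        print(os.path.basename(path))  # model.ckpt-100
        print(restore_params(path))    # {'w': [4.0], 'b': [2.0]}
```

Keeping a separate pointer to the latest checkpoint means the restore side never has to guess which step was saved last; it simply reads the index, which is the same role the "checkpoint" file plays for tf.train.get_checkpoint_state.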
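Example code (TensorFlow 1.x graph API; on TensorFlow 2.x it runs through tf.compat.v1):

import os

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style API; also available in TF 2.x

tf.disable_v2_behavior()  # required on TF 2.x for placeholders and sessions

# Toy problem: learn y = 4x + 2 with a single linear unit
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 2
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# isTrain = True  # use True on the first run to train and write checkpoints
isTrain = False   # use False afterwards to restore the latest checkpoint
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = 'test/'

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        os.makedirs(checkpoint_dir, exist_ok=True)  # Saver.save fails if the directory is missing
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                # Writes files prefixed test/model.ckpt-50 and test/model.ckpt-100,
                # and updates the index file test/checkpoint
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
    else:
        # Reads test/checkpoint to locate the most recent checkpoint prefix
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)  # overwrites the initial values
        else:
            pass  # no checkpoint found: w and b keep their random/zero initial values
    print(sess.run(w))
    print(sess.run(b))
    y_result = sess.run(y_predict, feed_dict={x: np.reshape(4, (1, 1))})
    print(y_result)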
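The graph above is least-squares fitting of y = 4x + 2 by plain gradient descent with learning rate 0.5, so after a successful training run the restored w and b should be close to 4 and 2, and the prediction at x = 4 close to 4 * 4 + 2 = 18. The stdlib-only sketch below applies the same update rule, w ← w − lr · ∂loss/∂w, by hand. The fixed inputs 0.0, 0.1, ..., 0.9 and the zero starting values are hypothetical choices that replace the random data and random initialization in the example, so the result is deterministic.

```python
def fit_line(xs, lr=0.5, steps=100):
    """Gradient descent on mean((y - (w * x + b)) ** 2) with targets y = 4x + 2,
    the same loss the TensorFlow graph minimizes."""
    w, b = 0.0, 0.0  # assumption: deterministic start instead of random_normal
    n = len(xs)
    for _ in range(steps):
        # Residuals r_i = y_i - y_pred_i
        residuals = [(4 * x + 2) - (w * x + b) for x in xs]
        # d(loss)/dw = -2/n * sum(r_i * x_i),  d(loss)/db = -2/n * sum(r_i)
        grad_w = -2.0 / n * sum(r * x for r, x in zip(residuals, xs))
        grad_b = -2.0 / n * sum(residuals)
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


if __name__ == "__main__":
    xs = [i / 10 for i in range(10)]
    w, b = fit_line(xs)
    print(round(w, 3), round(b, 3))  # close to 4 and 2
    print(round(4 * w + b, 3))       # prediction at x = 4, close to 18
```

Because the data are noiseless and exactly linear, the minimum of the loss is exactly w = 4, b = 2; 100 steps, as in train_steps above, already get within a few thousandths on these inputs.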
Original post: http://www.cnblogs.com/scarecrow-blog/p/7794388.html