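As of TensorFlow 1.x, the approach used here saves only the model's Variables, not the model as a whole, so the variables have to be redefined in code before they can be restored.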
Code to save the model:

import tensorflow as tf
import numpy as np

# Save to file.
# Remember to define the same dtype and shape when restoring.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

# tf.initialize_all_variables() is no longer valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save_model/save_pp.ckpt")
    print("Save to path: ", save_path)
The resulting file structure is as follows:
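The listing itself is not reproduced here. As a minimal sketch, the directory can be printed with the standard library. `list_checkpoint_files` is a hypothetical helper name, and `save_model` is the directory used by the save script. With the default Saver format in TensorFlow 1.x, one would typically expect a `checkpoint` file plus `save_pp.ckpt.index`, `save_pp.ckpt.meta`, and a `save_pp.ckpt.data-*` shard, though exact names can vary by version.

```python
import os


def list_checkpoint_files(directory="save_model"):
    """Return the sorted file names in `directory`, or [] if it does not exist."""
    if not os.path.isdir(directory):
        return []
    return sorted(os.listdir(directory))


# Print whatever the save script wrote (prints nothing if the directory is missing).
for name in list_checkpoint_files():
    print(name)
```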
Code to restore the model:

import tensorflow as tf

# Restore variables.
# Redefine the same shape and the same type for your variables.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

# No init step is needed; restore() assigns the saved values.
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "./save_model/save_pp.ckpt")
    print("v:", sess.run(v1))
    print("result:", sess.run(result))
Error message:
Not yet resolved.
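The error text is not reproduced here, so the cause cannot be confirmed. One common cause with this pattern, assuming both snippets ran in the same Python process (one script or one notebook), is that the second pair of `tf.Variable` calls is added to the same default graph and gets uniquified names such as `v1_1`. The new `Saver` then looks for keys that are not in the checkpoint. The sketch below builds each phase in its own `tf.Graph`, so the names stay `v1` and `v2`. `build_model`, `save_model`, and `restore_model` are hypothetical helper names, and the `tensorflow.compat.v1` import is an assumption that TensorFlow 1.14 or later is installed. On older 1.x releases, plain `import tensorflow as tf` without the `disable_v2_behavior()` call should play the same role.

```python
import os

import tensorflow.compat.v1 as tf  # assumption: TensorFlow >= 1.14 (or 2.x)

tf.disable_v2_behavior()  # keep graph/session semantics as in TF 1.x

CKPT_PATH = os.path.join("save_model", "save_pp.ckpt")


def build_model():
    """Define the variables with the same names, dtypes, and shapes each time."""
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
    return v1, v2, v1 + v2


def save_model(path=CKPT_PATH):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # A fresh graph, so the variables are named exactly "v1" and "v2".
    with tf.Graph().as_default():
        build_model()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            return saver.save(sess, path)


def restore_model(path=CKPT_PATH):
    # Another fresh graph: no leftover "v1"/"v2" to force names like "v1_1".
    with tf.Graph().as_default():
        v1, _, result = build_model()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # No initializer is run; restore() assigns the saved values.
            saver.restore(sess, path)
            return sess.run(v1), sess.run(result)


save_model()
v1_value, result_value = restore_model()
print("v1:", v1_value)
print("result:", result_value)
```

In an interactive session, calling `tf.reset_default_graph()` before redefining the variables has a similar effect.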