Checkpointing, Restoring, Fine-Tuning, Saving, and Loading TensorFlow Models


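All of the snippets below assume TensorFlow 1.x and omit their common setup. A minimal version of that setup is sketched here; the directory and file names are placeholders chosen for illustration, not values from the original post.

import os
import time
import tensorflow as tf

# Placeholder paths and file names used by every snippet below; adjust as needed.
MODEL_DIR = "./model"            # directory holding the checkpoints and the frozen graph
MODEL_NAME = "model.ckpt"        # first checkpoint
MODEL1_NAME = "model1.ckpt"      # extra checkpoint names for the multi-save example
MODEL2_NAME = "model2.ckpt"
PB_MODEL_NAME = "model.pb"       # frozen (binary) GraphDef
os.makedirs(MODEL_DIR, exist_ok=True)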
  • Checkpoint a model (*.index stores the variable names, *.meta stores the graph definition, *.data* stores the variable values; a quick check of the resulting files follows the snippet)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

sess.close()
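After the save above, the checkpoint directory should contain the index, meta graph, and data files described in the heading, plus a bookkeeping file named checkpoint. A quick way to confirm, assuming the setup sketched at the top:

print(sorted(os.listdir(MODEL_DIR)))
# Typically: ['checkpoint', 'model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta']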
  • Checkpoint a model (when saving the same model several times, skipping the meta graph saves time; an alternative using global_step follows the snippet)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
time.sleep(5)
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL1_NAME), write_meta_graph=False)
time.sleep(5)
saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL2_NAME), write_meta_graph=False)

sess.close()
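Instead of inventing a new file name for every save, saver.save also accepts a global_step argument that appends the step number to the checkpoint name, and the Saver itself can cap how many checkpoints are kept. This is a sketch of that pattern, assuming the session and variables from the snippet above are still open:

saver = tf.train.Saver(max_to_keep=5)     # keep only the five most recent checkpoints
for step in range(3):
    # ... run a few training steps here ...
    saver.save(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME), global_step=step,
               write_meta_graph=(step == 0))   # write the meta graph only once
# Produces model.ckpt-0, model.ckpt-1, model.ckpt-2 (each with its .index/.data files).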
  • Restore a model (if the network is rebuilt by hand, the *.meta file is not needed; a note on restoring only a subset of variables follows the snippet)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

print(sess.run([weights]))

sess.close()
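When the network is rebuilt by hand, the variable names and shapes must match what is stored in the checkpoint. If only some of the variables should come from the checkpoint, the Saver can be restricted with var_list; a small sketch using the variables defined above:

partial_saver = tf.train.Saver(var_list=[weights])   # restore only "weights" from the checkpoint
sess = tf.Session()
sess.run(tf.variables_initializer([biases]))         # "biases" keeps its ordinary initializer
partial_saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))
sess.close()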
  • Restore a model (rebuild the network from the *.meta file; the restored variables can be inspected by name, as shown after the snippet)
tf.reset_default_graph()

saver=tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
sess = tf.Session()
saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))

sess.close()
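import_meta_graph also rebuilds the variable collections, so the restored variables can be listed by name instead of being fetched through Python handles; for example, inside the session above:

for var in tf.global_variables():
    print(var.op.name, var.shape)   # e.g. weights (10, 10) and biases ()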
  • Restore a model (a directory can hold several checkpoints; the checkpoint file automatically records every checkpoint name and the most recent one; a shorter variant follows the snippet)
tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
ckpt = tf.train.get_checkpoint_state(MODEL_DIR)
saver.restore(sess, ckpt.model_checkpoint_path)

print(sess.run([weights]))

sess.close()
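tf.train.latest_checkpoint reads the same checkpoint bookkeeping file and is a slightly shorter way to get the most recent save; a fragment assuming the saver and session from the snippet above:

latest = tf.train.latest_checkpoint(MODEL_DIR)
if latest is not None:
    saver.restore(sess, latest)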
  • Fine-tune a model (restore part of the parameters of a previously trained model, add new parameters, and continue training; a sketch of the added training step follows the snippet)
def get_variables_available_in_checkpoint(variables, checkpoint_path, include_global_step=True):
    ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    if not include_global_step:
        ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
    vars_in_ckpt = {}
    for variable_name, variable in sorted(variables.items()):
        if variable_name in ckpt_vars_to_shape_map:
            if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
                vars_in_ckpt[variable_name] = variable
    return vars_in_ckpt

tf.reset_default_graph()

weights = tf.Variable(tf.random_normal([10, 10], stddev=0.1), name="weights")
biases = tf.Variable(0, name="biases")
other_weights = tf.Variable(tf.zeros([10, 10]))

variables_to_init = tf.global_variables()
variables_to_init_dict = {var.op.name: var for var in variables_to_init}
available_var_map = get_variables_available_in_checkpoint(variables_to_init_dict,
    "%s/%s" % (MODEL_DIR, MODEL_NAME), include_global_step=False)
tf.train.init_from_checkpoint("%s/%s" % (MODEL_DIR, MODEL_NAME), available_var_map)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run([weights]))

sess.close()
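The snippet above only initializes the overlapping variables from the checkpoint; actually continuing training means adding a loss and an optimizer on top of them. The lines below are a minimal sketch with a made-up objective purely for illustration: the extra ops would be defined before running the initializer in the snippet above, and the loop would run afterwards (the checkpointed values override the matched variables when the initializer runs).

loss = tf.reduce_mean(tf.square(tf.matmul(weights, other_weights)))   # illustrative objective only
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, var_list=[weights, other_weights])

# After sess.run(tf.global_variables_initializer()):
for _ in range(10):
    print(sess.run([train_op, loss])[1])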
  • Save a model (frozen binary *.pb graph; a helper for finding the output node names follows the snippet)
from tensorflow.python.framework.graph_util import convert_variables_to_constants

tf.reset_default_graph()

saver=tf.train.import_meta_graph("%s/%s.meta" % (MODEL_DIR, MODEL_NAME))
sess = tf.Session()
saver.restore(sess, "%s/%s" % (MODEL_DIR, MODEL_NAME))

graph_out = convert_variables_to_constants(sess, sess.graph_def, output_node_names=["weights"])
with tf.gfile.GFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "wb") as output:
    output.write(graph_out.SerializeToString())

sess.close()
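convert_variables_to_constants expects output_node_names as plain node-name strings (no ":0" suffix), e.g. ["weights"] above. When the node names are not known in advance, they can be listed before freezing, while the session above is still open:

for node in sess.graph_def.node:
    print(node.name)   # candidates for output_node_names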
  • Load a model (frozen binary *.pb graph; a helper for listing the imported operations follows the snippet)
tf.reset_default_graph()

sess = tf.Session()
with tf.gfile.FastGFile("%s/%s" % (MODEL_DIR, PB_MODEL_NAME), "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")  # an empty import name keeps the original tensor names, e.g. "weights:0"
# A frozen graph contains only constants, so no variable initializer is needed.

print(sess.run([tf.get_default_graph().get_tensor_by_name("weights:0")]))

sess.close()
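If the tensor names inside an unfamiliar *.pb file are not known, the imported operations can be listed before querying the graph; a small sketch that would run inside the session above:

for op in tf.get_default_graph().get_operations():
    print(op.name, [t.shape for t in op.outputs])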