码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow细节-P112-模型持久化

时间:2019-10-04 22:45:18      阅读:87      评论:0      收藏:0      [点我收藏+]

标签:rand   不用   png   session   meta   mod   gem   均值   on()   

第一个代码

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")

看看看,就是上面:注意两个方面
(1)saver = tf.train.Saver()提前设定好
(2)saver.save(sess, "Saved_model/model.ckpt")这里面有sess要注意!

第二个代码

import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(result)

这里有三个要注意的点
(1)上面定义好了模型(变量名字与第一个代码一样),Saver()里什么都没有
(2)saver.restore(sess, "Saved_model/model.ckpt")里有sess,ckpt是数据
(3)result是读取数据的结果,跟这里的变量没关系

第三个代码

import tensorflow as tf
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(v1) 
    print sess.run(v2) 
    print sess.run(v3) 

看这里,由于v3是一个变量,要输出的话需要先进行初始化(v1、v2不用)

下面,就是滑动平均了

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
    print(variables.name)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
        print(variables.name)
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0 ?v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "Saved_model/model2.ckpt")
    print(sess.run([v, ema.average(v)]))

技术图片
从上面的代码和图片可以看到开始时是一个变量,后来经过maintain_averages_op = ema.apply(tf.global_variables())就多了一个影子变量,这样子,就把影子变量存好了

下面就是加载滑动平均的影子变量了

v = tf.Variable(0, dtype=tf.float32, name="v")

# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)

注意重命名

Tensorflow细节-P112-模型持久化

标签:rand   不用   png   session   meta   mod   gem   均值   on()   

原文地址:https://www.cnblogs.com/liuboblog/p/11623417.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!