码迷,mamicode.com
首页 > 其他好文 > 详细

第五章 MNIST数字识别问题(二)

时间:2018-03-11 02:45:02      阅读:188      评论:0      收藏:0      [点我收藏+]

标签:平均值   exp   oat   2.0   lob   and   lap   节点   parameter   

4.1. ckpt文件保存方法

在对模型进行加载时候,需要定义出与原来的计算图结构完全相同的计算图,然后才能进行加载,并且不需要对定义出来的计算图进行初始化操作。 
这样保存下来的模型,会在其文件夹下生成三个文件,分别是: 
* .ckpt.meta文件,保存tensorflow模型的计算图结构。 
* .ckpt文件,保存计算图下所有变量的取值。 
* checkpoint文件,保存目录下所有模型文件列表。

技术分享图片
import tensorflow as tf
#保存计算两个变量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
result = v1 + v2

init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "Saved_model/model.ckpt")
#加载保存了两个变量和的模型
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(result)

INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-1.6226364]
#直接加载持久化的图。因为之前没有导出v3,所以这里会报错
saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))

with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model.ckpt")
    print sess.run(v1) 
    print sess.run(v2) 
    print sess.run(v3) 
INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
[-0.81131822]
[-0.81131822]

# 变量重命名,这样可以通过字典将模型保存时的变量名和需要加载的变量联系起来
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
saver = tf.train.Saver({"v1": v1, "v2": v2})
View Code

 

4.2.1 滑动平均类的保存

技术分享图片
import tensorflow as tf
#使用滑动平均
v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables(): print variables.name
    
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables(): print variables.name
v:0
v:0
v/ExponentialMovingAverage:0

#保存滑动平均模型
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的时候会将v:0  v/ExponentialMovingAverage:0这两个变量都存下来。
    saver.save(sess, "Saved_model/model2.ckpt")
    print sess.run([v, ema.average(v)])
10.0, 0.099999905]

#加载滑动平均模型
v = tf.Variable(0, dtype=tf.float32, name="v")

# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999
View Code

 

4.2.2 variables_to_restore函数的使用样例

import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print ema.variables_to_restore()

#等同于saver = tf.train.Saver(ema.variables_to_restore())
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
{u‘v/ExponentialMovingAverage‘: <tf.Variable ‘v:0‘ shape=() dtype=float32_ref>}

 

4.3. pb文件保存方法

#pb文件的保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [‘add‘])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
           f.write(output_graph_def.SerializeToString())

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
------------------------------------------------------------------------
#加载pb文件
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
   
    with gfile.FastGFile(model_filename, ‘rb‘) as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print sess.run(result)

[array([ 3.], dtype=float32)]

张量的名称后面有:0,表示是某个计算节点的第一个输出,而计算节点本身的名称后是没有:0的。

第五章 MNIST数字识别问题(二)

标签:平均值   exp   oat   2.0   lob   and   lap   节点   parameter   

原文地址:https://www.cnblogs.com/exciting/p/8542859.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!