码迷,mamicode.com
首页 > 其他好文 > 详细

Tensorflow Learning1 模型的保存和恢复

时间:2019-09-09 22:44:53      阅读:131      评论:0      收藏:0      [点我收藏+]

标签:pre   state   ret   目的   ble   color   指定   var   固定   


CKPT->pb

Demo

解析

tensor name 和 node name 的区别

Pb 的恢复



CKPT->pb

tensorflow的模型保存有两种形式:

1. ckpt:可以恢复图和变量,继续做训练

2. pb : 将图序列化,变量成为固定的值,,只可以做inference;不能继续训练


Demo


  1 def freeze_graph(input_checkpoint,output_graph):
  2 
  3     ‘‘‘
  4     :param input_checkpoint:
  5     :param output_graph: PB模型保存路径
  6     :return
  7       void
  8     ‘‘‘
  9 
 10     # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 11     # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 12 
 13     # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 14     output_node_names = "InceptionV3/Logits/SpatialSqueeze" # 如果是多个输出节点,使用 ‘,’号隔开
 15 
 16     ############################     Step1: 从ckpt中恢复图:     #############################################
 17     saver = tf.train.import_meta_graph(input_checkpoint + ‘.meta‘, clear_devices=True)
 18     graph = tf.get_default_graph() # 获得默认的图, 可以省略
 19     input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图,可以省略
 20 
 21     with tf.Session() as sess: # 会使用默认的图 作为当前的图
 22         saver.restore(sess, input_checkpoint) #恢复图并得到数据
 23 
 24         ########################     Step2: 创建持久化对象,指定sess,图、以及输出的序列化节点信息    ##############
 25         output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
 26             sess=sess,
 27             input_graph_def=input_graph_def,# 等于:sess.graph_def
 28             output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 29         #########################    Step3: 模型持久化   #######################################################
 30         with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 31             f.write(output_graph_def.SerializeToString()) #序列化输出
 32         print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 33         # for op in graph.get_operations():
 34 
 35         #     print(op.name, op.values())
 36 
 37 
 38 ########################### 调用方式 ################################
 39 # 输入ckpt模型路径
 40 input_checkpoint=‘models/model.ckpt-10000‘
 41 # 输出pb模型的路径
 42 out_pb_path="models/pb/frozen_model.pb"
 43 # 调用freeze_graph将ckpt转为pb
 44 freeze_graph(input_checkpoint,out_pb_path)

解析

函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。

freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

在保存pb的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称;

tensor name 和 node name 的区别

node name 是 图 的节点,里面包含了很多操作和tensor

tensor 是 node 里面的一个组成部分;

以input 为例,“input:0”是张量的名称,而"input"表示的是节点的名称

PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"input:0"


Tensorflow Learning1 模型的保存和恢复

标签:pre   state   ret   目的   ble   color   指定   var   固定   

原文地址:https://www.cnblogs.com/greentomlee/p/11494383.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!