
tensorflow - training checkpoints with tf.train.Saver

Posted: 2018-12-05 20:26:17


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:myhaspl@myhaspl.com
"""

import tensorflow as tf
g1=tf.Graph()

with g1.as_default():
    with tf.name_scope("input_Variable"):
        my_var=tf.Variable(1,dtype=tf.float32)
    with tf.name_scope("global_step"):
        my_step=tf.Variable(0,dtype=tf.int32)
    with tf.name_scope("update"):
        varop=tf.assign(my_var,tf.multiply(tf.log(tf.add(my_var,1)),1))
        stepop=tf.assign_add(my_step,1)
        addop=tf.group([varop,stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar',my_var)
    with tf.name_scope("global_ops"):
        init=tf.global_variables_initializer()
        merged_summaries=tf.summary.merge_all()

with tf.Session(graph=g1) as sess:
    writer=tf.summary.FileWriter('sum_vars',sess.graph)
    sess.run(init)
    #---0
    step,var,summary=sess.run([my_step,my_var,merged_summaries])
    writer.add_summary(summary,global_step=step)
    print step,var
    saver=tf.train.Saver()
    #1-49
    for i in xrange(1,50):
        sess.run(addop)
        step,var,summary=sess.run([my_step,my_var,merged_summaries])
        writer.add_summary(summary,global_step=step)
        print step,var
        if i%5==0:
            saver.save(sess,'./myvar-model/myvar-model',global_step=i)
    saver.save(sess,'./myvar-model/myvar-model',global_step=49)
    writer.flush()
    writer.close()

Last lines of output (step, my_var):

38 0.0512373
39 0.04996785
40 0.048759546
41 0.04760808
42 0.04650955
43 0.045460388
44 0.04445735
45 0.04349747
46 0.042578023
47 0.041696515
48 0.040850647
49 0.04003831
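The update op sets my_var to log(my_var + 1) once per step, starting from my_var = 1 (the multiplication by 1 has no effect), so the printed values can be checked without TensorFlow. The Python 3 sketch below is a plain replay of that rule; the helper name `replay` is hypothetical. It uses float64 while the graph uses float32, so the last digits may differ slightly from the output above.

```python
import math

def replay(steps, start=1.0):
    """Hypothetical helper: replay my_var <- log(my_var + 1) for `steps` updates.

    Returns [value at step 0, value at step 1, ..., value at step `steps`].
    Uses float64; the TensorFlow graph uses float32.
    """
    values = [start]
    for _ in range(steps):
        values.append(math.log(values[-1] + 1.0))
    return values

vals = replay(50)
# Steps 38..49 match the tail of the first script's output;
# step 50 is the first value printed by the resumed run.
for step in range(38, 51):
    print(step, round(vals[step], 8))
```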

tf.train.Saver saves the variables of the data-flow graph to binary checkpoint files.
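With the default settings of tf.train.Saver in TensorFlow 1.x (V2 checkpoint format, a single shard, `write_meta_graph=True`, `max_to_keep=5`), each `saver.save(sess, prefix, global_step=i)` writes `<prefix>-<i>.index`, `<prefix>-<i>.data-00000-of-00001` and `<prefix>-<i>.meta`, updates a small text file named `checkpoint` in the same directory (this is what `tf.train.get_checkpoint_state` reads), and deletes the oldest checkpoint once more than five exist. The Python 3 sketch below only models that bookkeeping for the first script's save schedule; it does not call TensorFlow, and `kept_checkpoints` and `checkpoint_files` are hypothetical helpers.

```python
from collections import deque

def kept_checkpoints(saved_steps, max_to_keep=5):
    """Hypothetical model of Saver's retention: keep only the newest checkpoints."""
    kept = deque(maxlen=max_to_keep)  # appending to a full deque drops the oldest entry
    for step in saved_steps:
        kept.append(step)
    return list(kept)

def checkpoint_files(prefix, step):
    """Hypothetical helper: files a single-shard V2 save with a meta graph produces."""
    base = "{}-{}".format(prefix, step)
    return [base + ".index", base + ".data-00000-of-00001", base + ".meta"]

# The first script saves when i % 5 == 0 for i in 1..49, then once more at step 49.
saved = [i for i in range(1, 50) if i % 5 == 0] + [49]
kept = kept_checkpoints(saved)
print(kept)  # [30, 35, 40, 45, 49]
for step in kept:
    print(checkpoint_files("./myvar-model/myvar-model", step))
```

Under these assumptions, only the checkpoints for steps 30, 35, 40, 45 and 49 remain on disk after the first script, with step 49 recorded as the latest.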

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Sep  6 10:16:37 2018
@author: myhaspl
@email:myhaspl@myhaspl.com
"""

import tensorflow as tf
import os
g1=tf.Graph()

with g1.as_default(): 
    with tf.name_scope("input_Variable"):        
        my_var=tf.Variable(1,dtype=tf.float32)
    with tf.name_scope("global_step"):
        my_step=tf.Variable(0,dtype=tf.int32,trainable=False)
    with tf.name_scope("update"):
        varop=tf.assign(my_var,tf.multiply(tf.log(tf.add(my_var,1)),1))
        stepop=tf.assign_add(my_step,1)
        addop=tf.group([varop,stepop])
    with tf.name_scope("summaries"):
        tf.summary.scalar('myvar',my_var)
    with tf.name_scope("global_ops"):
        init=tf.global_variables_initializer()
        merged_summaries=tf.summary.merge_all()

with tf.Session(graph=g1) as sess:  
    writer=tf.summary.FileWriter('sum_vars',sess.graph)
    sess.run(init)

    saver=tf.train.Saver()

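    # If a checkpoint was saved earlier, restore the model and continue from it
    init_step=0
    ckpt=tf.train.get_checkpoint_state(os.path.join(os.getcwd(),'myvar-model'))
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess,ckpt.model_checkpoint_path)
        # The step number is the suffix after the last '-' in the checkpoint path
        init_step=int(ckpt.model_checkpoint_path.rsplit('-',1)[1])
        print "Reading checkpoint file..."
    for i in xrange(init_step,100):
        step,var,summary=sess.run([my_step,my_var,merged_summaries])
        writer.add_summary(summary,global_step=step)
        print step,var,init_step
        # Save every 5 steps, but only up to step 50
        if i%5==0 and i<=50:
            print "Saving checkpoint file"
            saver.save(sess,'./myvar-model/myvar-model',global_step=i)
        sess.run(addop)

    writer.flush()
    writer.close()

On the first run, the code above saves a checkpoint every 5 steps up to step 50. From the second run on, the latest checkpoint (step 50) is read back and the loop resumes from step=50.

Output of the second run (columns: step, my_var, init_step):

Reading checkpoint file...
50 0.03925755 50
Saving checkpoint file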
51 0.038506564 50
52 0.037783686 50
53 0.03708737 50
54 0.036416177 50
55 0.035768777 50
56 0.03514393 50
...
...
...
93 0.021334965 50
94 0.02111056 50
95 0.02089082 50
96 0.0206756 50
97 0.020464761 50
98 0.020258171 50
99 0.020055704 50
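The resume logic in the second script is plain arithmetic on the checkpoint path: the step is parsed from the file name and the loop continues from there. The Python 3 sketch below replays it; `step_from_checkpoint_path` is a hypothetical helper mirroring the script's `rsplit('-',1)` expression, and the path is assumed to be the step-50 checkpoint left by the first run (saves stop once i exceeds 50).

```python
def step_from_checkpoint_path(path):
    """Hypothetical helper: return the integer after the last '-' in a checkpoint path."""
    return int(path.rsplit('-', 1)[1])

latest = "./myvar-model/myvar-model-50"   # assumed latest checkpoint after the first run
init_step = step_from_checkpoint_path(latest)
resumed_steps = list(range(init_step, 100))  # same range as xrange(init_step,100)
# The save condition from the script: only i == 50 qualifies on the second run.
saves_in_second_run = [i for i in resumed_steps if i % 5 == 0 and i <= 50]
print(init_step, resumed_steps[0], resumed_steps[-1], saves_in_second_run)  # 50 50 99 [50]
```

Parsing the file name only works when the checkpoint was saved with `global_step`; a path without a numeric suffix makes this expression raise an exception. Because `my_step` is itself a saved variable, reading it after `saver.restore` (for example `init_step=sess.run(my_step)`) would yield the same 50 without relying on the file name.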


Original article: http://blog.51cto.com/13959448/2326699
