码迷,mamicode.com
首页 > 其他好文 > 详细

TensorFlow 存储与读取

时间:2017-12-12 21:49:05      阅读:180      评论:0      收藏:0      [点我收藏+]

标签:font   post   需要   一点   sse   var   dev   ret   文件   

之前通过CNN进行的MNIST训练识别成功率已经很高了,不过每次运行都需要消耗很多的时间。在实际使用的时候,每次都要选经过训练后在进行识别那就太不方便了。

所以我们学习一下如何将训练习得的参数保存起来,然后在需要用的时候直接使用这些参数进行快速的识别。

本章节代码来自《Tensorflow 实战Google深度学习框架》5.5 TensorFlow 最佳实践样例程序  针对书中的代码做了一点点的调整。

 

mnist_inference.py:

#coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))
    if regularizer != None:
        tf.add_to_collection(losses, regularizer(weights))
    return weights

def inference(input_tensor, regularizer):
    with tf.variable_scope(layer1):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope(layer2):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

这里是向前传播的方法文件。这个方法在训练和测试的过程都需要用到,将它抽离出来既能使用起来更加方便,也能保证训练和测试时使用的方法保持一致。

get_variable

 weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))

源代码第十行使用get_variable函数获取变量。

在训练网络是会创建这些变量;

在测试时会通过训练时保存的模型加载这些变量的值。

 

(未完待续。。。。)

TensorFlow 存储与读取

标签:font   post   需要   一点   sse   var   dev   ret   文件   

原文地址:http://www.cnblogs.com/guolaomao/p/8028600.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!