码迷,mamicode.com
首页 > 其他好文 > 详细

Variable和get_variable的用法以及区别

时间:2018-09-28 22:13:35      阅读:320      评论:0      收藏:0      [点我收藏+]

标签:结果   class   constant   init   不同的   var   span   也会   get   

在tensorflow中,可以使用tf.Variable来创建一个变量,也可以使用tf.get_variable来创建一个变量,但是在一个模型需要使用其他模型的变量时,tf.get_variable就派上大用场了。

先分别介绍两个函数的用法:

 1 import tensorflow as tf
 2 var1 = tf.Variable(1.0,name=firstvar)
 3 print(var1:,var1.name)
 4 var1 = tf.Variable(2.0,name=firstvar)
 5 print(var1:,var1.name)
 6 var2 = tf.Variable(3.0)
 7 print(var2:,var2.name)
 8 var2 = tf.Variable(4.0)
 9 print(var2:,var2.name)
10 get_var1 = tf.get_variable(name=firstvar,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3))
11 print(get_var1:,get_var1.name)
12 get_var1 = tf.get_variable(name=firstvar1,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))
13 print(get_var1:,get_var1.name)
14 
15 with tf.Session() as sess:
16     sess.run(tf.global_variables_initializer())
17     print(var1=,var1.eval())
18     print(var2=,var2.eval())
19     print(get_var1=,get_var1.eval())

结果如下:

技术分享图片

我们来分析一下代码,tf.Varibale是以定义的变量名称为唯一标识的,如var1,var2,所以可以重复地创建name=‘firstvar‘的变量,但是tensorflow会给它们按顺序取后缀,如firstvar_1:0,firstval_2:0,...,如果没有制定名字,系统会自动加上一个名字Variable:0。而且由于tf.Varibale是以定义的变量名称为唯一标识的,所以当第二次命名同一个变量名时,第一个变量就会被覆盖,所以var1由1.0变成2.0。

对于tf.get_variable,它是以指定的name属性为唯一标识,而不是定义的变量名称,所以不能同时定义两个变量name是相同的,例如下面这种就会报错:

1 get_var1 = tf.get_variable(name=a,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3))
2 print(get_var1:,get_var1.name)
3 get_var2 = tf.get_variable(name=a,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))
4 print(get_var1:,get_var1.name)

这样就会报错了。如果我们想声明两次相同name的变量,这时variable_scope就派上用场了,可以使用variable_scope将它们分开:

1 import tensorflow as tf
2 with tf.variable_scope(test1):
3     get_var1 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
4 with tf.variable_scope(test2):
5     get_var2 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
6 print(get_var1:,get_var1.name)
7 print(get_var2:,get_var2.name)

这样就不会报错了,variable_scope相当于声明了作用域,这样在不同的作用域存在相同的变量就不会冲突了,结果如下:

技术分享图片

当然,scope还支持嵌套:

1 import tensorflow as tf
2 with tf.variable_scope(test1,):
3     get_var1 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
4     with tf.variable_scope(test2,):
5         get_var2 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
6 print(get_var1:,get_var1.name)
7 print(get_var2:,get_var2.name)

输出结果为:

技术分享图片

怎么样?可以对照上面的结果体会一下不同。那么如何通过get_variable来实现变量共享呢?这就要用到variable_scope里的一个属性:reuse,顾名思义嘛,当把reuse设置成True时就可以了,它表示使用已经定义过的变量,这是get_variable就不会再创建新的变量,而是去找与name相同的变量:

import tensorflow as tf
with tf.variable_scope(test1,):
    get_var1 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
    with tf.variable_scope(test2,):
        get_var2 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
print(get_var1:,get_var1.name)
print(get_var2:,get_var2.name)
with tf.variable_scope(test1,reuse=True):
    get_var3 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
    with tf.variable_scope(test2,):
        get_var4 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
print(get_var3:,get_var3.name)
print(get_var4:,get_var4.name)

输出结果如下:

技术分享图片

当然前面说过,reuse=True是使用前面已经创建过的变量,如果仅仅只有从第八行到最后的代码,也会报错的,如果还是想这么做,就需要把reuse属性设置成tf.AUTO_REUSE

1 import tensorflow as tf
2 with tf.variable_scope(test1,reuse=tf.AUTO_REUSE):
3     get_var3 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
4     with tf.variable_scope(test2,):
5         get_var4 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
6 print(get_var3:,get_var3.name)
7 print(get_var4:,get_var4.name)

此时就不会报错,tf.AUTO_REUSE可以实现第一次调用variable_scope时,传入的reuse值为False,再次调用时,传入reuse的值就会自动变为True。

Variable和get_variable的用法以及区别

标签:结果   class   constant   init   不同的   var   span   也会   get   

原文地址:https://www.cnblogs.com/wf-ml/p/9721027.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!