码迷,mamicode.com
首页 > 其他好文 > 详细

『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究

时间:2018-02-15 21:28:19      阅读:937      评论:0      收藏:0      [点我收藏+]

标签:返回值   reg   没有   函数   object   文档   定义   python   pytorch   

查看非叶节点梯度的两种方法

在反向传播过程中非叶子节点的导数计算完之后即被清空。若想查看这些变量的梯度,有两种方法:

  • 使用autograd.grad函数
  • 使用hook

autograd.gradhook方法都是很强大的工具,更详细的用法参考官方api文档,这里举例说明基础的使用。推荐使用hook方法,但是在实际使用中应尽量避免修改grad的值。

求z对y的导数

x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

# hook
# hook没有返回值,参数是函数,函数的参数是梯度值
def variable_hook(grad):
    print("hook梯度输出:\r\n",grad)

hook_handle = y.register_hook(variable_hook)         # 注册hook
z.backward(retain_graph=True)                        # 内置输出上面的hook
hook_handle.remove()                                 # 释放

print("autograd.grad输出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度输出:
 Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

autograd.grad输出:
 (Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
,)

 

多次反向传播试验

实际就是使用retain_graph参数,

# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

如果不使用retain_graph参数,

  • 实际上效果是一样的,AccumulateGrad object仍然会积累梯度
  • 除了叶子节点之外,高层节点需要重新定义,因为原图已经传播了,需要基于原叶子建立新图,实际上第二次的z.backward()已经不是第一次的z所在的图了,这里看似简单,实际上体现了动态图的技术,静态图初始化之后会留在内存中等待feed数据,但是动态图不会,反向传播后就已经被废弃,下次要么完全重建(如下),要么反向传播之后指定不舍弃图z.backward(retain_graph=True),总之和常规的数据结构不同,图上的节点是隶属于图的属性的,TensorFlow中会一直存留,PyTorch中就会backward后直接舍弃(默认时)。
# 构件图
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward()
print(w.grad)
y = w.mul(x)  # <-----
z = y.sum()  # <-----
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

『PyTorch』第五弹_深入理解autograd_下:Variable梯度探究

标签:返回值   reg   没有   函数   object   文档   定义   python   pytorch   

原文地址:https://www.cnblogs.com/hellcat/p/8449801.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!