码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch显存越来越多的一个自己没注意的原因

时间:2019-01-07 20:55:39      阅读:1630      评论:0      收藏:0      [点我收藏+]

标签:lan   highlight   观察   blog   产生   就是   lang   csdn   数据类型   

optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss

参考:https://blog.csdn.net/qq_27292549/article/details/80250031

我和博主犯了一毛一样的低级错误。。。。

 

下面是原博解释:

运行着就发现显存炸了

观察了一下发现随着每个batch显存消耗在不断增大..

参考了别人的代码发现那句loss一般是这样写 

loss_sum += loss.data[0]

这是因为输出的loss的数据类型是Variable。

而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。

如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~

总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...

 

补充:

用Tensor计算也是有坑的,要写成:

 train_loss += loss.item()

不然显存还是会炸。。。。。

 

pytorch显存越来越多的一个自己没注意的原因

标签:lan   highlight   观察   blog   产生   就是   lang   csdn   数据类型   

原文地址:https://www.cnblogs.com/Charlene-HRI/p/10234656.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!