标签:概率 系统 add data 感受 哪些 公众 batch 图片
算是动态图的一个坑吧。记录loss信息的时候直接使用了输出的Variable。
应该不止我经历过这个吧...
久久不用又会不小心掉到这个坑里去...
for data, label in trainloader:
......
out = model(data)
loss = criterion(out, label)
loss_sum += loss # <--- 这里
......
运行着就发现显存炸了
观察了一下发现随着每个batch显存消耗在不断增大..
参考了别人的代码发现那句loss一般是这样写 /(ㄒoㄒ)/~~
loss_sum += loss.data[0]
这是因为输出的loss的数据类型是Variable。
而PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。
如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大~那么消耗的显存也就越来越大~~
总之使用Variable的数据时候要非常小心。不是必要的话尽量使用Tensor来进行计算...
包括数据的输入时候,如果“过早”把数据丢到Variable里面去,那么可能也会被系统视为网络的一部分。所以,要投入的时候再把数据丢到Variable里面去吧~
题外话
想更多感受动态图的话,可以通过Variable的grad_fun来观察到该Variable是通过什么运算得到的(前提是前面的Variable的required_grad置为True)。
大概是这样
>> >> z = x + y
>> z.grad_fn
out:
<AddBackward1 at 0x107286240>
【深度学习实战】pytorch中如何处理RNN输入变长序列padding
【机器学习基本理论】详解最大后验概率估计(MAP)的理解
【区块链】区块链最通俗入门教程
欢迎关注公众号学习交流~
标签:概率 系统 add data 感受 哪些 公众 batch 图片
原文地址:https://blog.51cto.com/15009309/2554205