前言
在上篇文章《浅谈深度学习:如何计算模型以及中间变量的显存占用大小》中我们对如何计算各种变量所占显存大小进行了一些探索。而这篇文章我们着重讲解如何利用Pytorch深度学习框架的一些特性,去查看我们当前使用的变量所占用的显存大小,以及一些优化工作。以下代码所使用的平台框架为Pytorch。
优化显存
在Pytorch中优化显存是我们处理大量数据时必要的做法,因为我们并不可能拥有无限的显存。显存是有限的,而数据是无限的,我们只有优化显存的使用量才能够最大化地利用我们的数据,实现多种多样的算法。
估测模型所占的内存
上篇文章中说过,一个模型所占的显存无非是这两种:
- 模型权重参数
- 模型所储存的中间变量
其实权重参数一般来说并不会占用很多的显存空间,主要占用显存空间的还是计算时产生的中间变量,当我们定义了一个model之后,我们可以通过以下代码简单计算出这个模型权重参数所占用的数据量:
import numpy as np
# model是我们在pytorch定义的神经网络层
# model.parameters()取出这个model所有的权重参数
para = sum([np.prod(list(p.size())) for p in model.parameters()])
假设我们有这样一个model:
Sequential(
(conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu_1): ReLU(inplace)
(conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu_2): ReLU(inplace)
(pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
然后我们得到的para
是112576
,但是我们计算出来的仅仅是权重参数的“数量”,单位是B,我们需要转化一下:
# 下面的type_size是4,因为我们的参数是float32也就是4B,4个字节
print(‘Model {} : params: {:4f}M‘.format(model._get_name(), para * type_size / 1000 / 1000))
这样就可以打印出:
Model Sequential : params: 0.450304M
但是我们之前说过一个神经网络的模型,不仅仅有权重参数还要计算中间变量的大小。怎么去计算,我们可以假设一个输入变量,然后将这个输入变量投入这个模型中,然后我们主动提取这些计算出来的中间变量:
# model是我们加载的模型
# input是实际中投入的input(Tensor)变量
# 利用clone()去复制一个input,这样不会对input造成影响
input_ = input.clone()
# 确保不需要计算梯度,因为我们的目的只是为了计算中间变量而已
input_.requires_grad_(requires_grad=False)
mods = list(model.modules())
out_sizes = []
for i in range(1, len(mods)):
m = mods[i]
# 注意这里,如果relu激活函数是inplace则不用计算
if isinstance(m, nn.ReLU):
if m.inplace:
continue
out = m(input_)
out_sizes.append(np.array(out.size()))
input_ = out
total_nums = 0
for i in range(len(out_sizes)):
s = out_sizes[i]
nums = np.prod(np.array(s))
total_nums += nums
上面得到的值是模型在运行时候产生所有的中间变量的“数量”,当然我们需要换算一下:
# 打印两种,只有 forward 和 foreward、backward的情况
print(‘Model {} : intermedite variables: {:3f} M (without backward)‘
.format(model._get_name(), total_nums * type_size / 1000 / 1000))
print(‘Model {} : intermedite variables: {:3f} M (with backward)‘
.format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))
因为在backward
的时候所有的中间变量需要保存下来再来进行计算,所以我们在计算backward
的时候,计算出来的中间变量需要乘个2。
然后我们得出,上面这个模型的中间变量需要的占用的显存,很显然,中间变量占用的值比模型本身的权重值多多了。如果进行一次backward
那么需要的就更多。
Model Sequential : intermedite variables: 336.089600 M (without backward)
Model Sequential : intermedite variables: 672.179200 M (with backward)
我们总结一下之前的代码:
# 模型显存占用监测函数
# model:输入的模型
# input:实际中需要输入的Tensor变量
# type_size 默认为 4 默认类型为 float32
def modelsize(model, input, type_size=4):
para = sum([np.prod(list(p.size())) for p in model.parameters()])
print(‘Model {} : params: {:4f}M‘.format(model._get_name(), para * type_size / 1000 / 1000))
input_ = input.clone()
input_.requires_grad_(requires_grad=<span class="hljs-keyword">False</span>)
mods = list(model.modules())
out_sizes = []
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(<span class="hljs-number">1</span>, len(mods)):
m = mods[i]
<span class="hljs-keyword">if</span> isinstance(m, nn.ReLU):
<span class="hljs-keyword">if</span> m.inplace:
<span class="hljs-keyword">continue</span>
out = m(input_)
out_sizes.append(np.array(out.size()))
input_ = out
total_nums = <span class="hljs-number">0</span>
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(len(out_sizes)):
s = out_sizes[i]
nums = np.prod(np.array(s))
total_nums += nums
print(<span class="hljs-string">‘Model {} : intermedite variables: {:3f} M (without backward)‘</span>
.format(model._get_name(), total_nums * type_size / <span class="hljs-number">1000</span> / <span class="hljs-number">1000</span>))
print(<span class="hljs-string">‘Model {} : intermedite variables: {:3f} M (with backward)‘</span>
.format(model._get_name(), total_nums * type_size*<span class="hljs-number">2</span> / <span class="hljs-number">1000</span> / <span class="hljs-number">1000</span>))
当然我们计算出来的占用显存值仅仅是做参考作用,因为Pytorch
在运行的时候需要额外的显存值开销,所以实际的显存会比我们计算的稍微大一些。
关于inplace=False
我们都知道激活函数Relu()
有一个默认参数inplace
,默认设置为False
,当设置为True
时,我们在通过relu()
计算时的得到的新值不会占用新的空间而是直接覆盖原来的值,这也就是为什么当inplace参数设置为True时可以节省一部分内存的缘故。
牺牲计算速度减少显存使用量
在Pytorch-0.4.0
出来了一个新的功能,可以将一个计算过程分成两半,也就是如果一个模型需要占用的显存太大了,我们就可以先计算一半,保存后一半需要的中间结果,然后再计算后一半。
也就是说,新的checkpoint
允许我们只存储反向传播所需要的部分内容。如果当中缺少一个输出(为了节省内存而导致的),checkpoint
将会从最近的检查点重新计算中间输出,以便减少内存使用(当然计算时间增加了):
# 输入
input = torch.rand(1, 10)
# 假设我们有一个非常深的网络
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
output = model(input)
上面的模型需要占用很多的内存,因为计算中会产生很多的中间变量。为此checkpoint
就可以帮助我们来节省内存的占用了。
# 首先设置输入的input=>requires_grad=True
# 如果不设置可能会导致得到的gradient为0
input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
# 定义要计算的层函数,可以看到我们定义了两个
# 一个计算前500个层,另一个计算后500个层
def run_first_half(*args):
x = args[0]
for layer in layers[:500]:
x = layer(x)
return x
def run_second_half(*args):
x = args[0]
for layer in layers[500:-1]:
x = layer(x)
return x
# 我们引入新加的checkpoint
from torch.utils.checkpoint import checkpoint
x = checkpoint(run_first_half, input)
x = checkpoint(run_second_half, x)
# 最后一层单独调出来执行
x = layers-1
x.sum.backward() # 这样就可以了
对于Sequential-model
来说,因为Sequential()
中可以包含很多的block
,所以官方提供了另一个功能包:
input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
from torch.utils.checkpoint import checkpoint_sequential
# 分成两个部分
num_segments = 2
x = checkpoint_sequential(model, num_segments, input)
x.sum().backward() # 这样就可以了
跟踪显存使用情况
显存的使用情况,在编写程序中我们可能无法精确计算,但是我们可以通过pynvml这个Nvidia的Python环境库和Python的垃圾回收工具,可以实时地打印我们使用的显存以及哪些Tensor使用了我们的显存。
类似于下面的报告:
# 08-Jun-18-17:56:51-gpu_mem_prof
At __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">39</span> Total Used Memory:<span class="hljs-number">399.4</span> Mb
At __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span> Total Used Memory:<span class="hljs-number">992.5</span> Mb
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span> (<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>) <span class="hljs-number">1.82</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.Tensor‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">40</span> (<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>) <span class="hljs-number">5.46</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.Tensor‘</span>&<span class="hljs-keyword">gt</span>;
At __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> Total Used Memory:<span class="hljs-number">1088.5</span> Mb
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">64</span>, <span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">14</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">128</span>, <span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">28</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">56</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">64</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">256</span>, <span class="hljs-number">256</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">2.25</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">512</span>, <span class="hljs-number">256</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">4.5</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">9.0</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">64</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">682</span>, <span class="hljs-number">700</span>) <span class="hljs-number">5.46</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.Tensor‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">128</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">256</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">512</span>,) <span class="hljs-number">0</span>.<span class="hljs-number">00</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">3</span>,) <span class="hljs-number">1.14</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.Tensor‘</span>&<span class="hljs-keyword">gt</span>;
+ __main_<span class="hljs-number">_</span> &<span class="hljs-keyword">lt</span>;module&<span class="hljs-keyword">gt</span>;: line <span class="hljs-number">126</span> (<span class="hljs-number">256</span>, <span class="hljs-number">128</span>, <span class="hljs-number">3</span>, <span class="hljs-number">3</span>) <span class="hljs-number">1.12</span> M &<span class="hljs-keyword">lt</span>;class <span class="hljs-string">‘torch.nn.parameter.Parameter‘</span>&<span class="hljs-keyword">gt</span>;
...</code></pre>
以下是相关的代码,目前代码依然有些地方需要修改,等修改完善好我会将完整代码以及使用说明放到github上:https://github.com/Oldpan/Pytorch-Memory-Utils
请大家多多留意。
import datetime
import linecache
import os
import gc
import pynvml
import torch
import numpy as np
print_tensor_sizes = True
last_tensor_sizes = set()
gpu_profile_fn = f‘{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_prof.txt‘
# if ‘GPU_DEBUG‘ in os.environ:
# print(‘profiling gpu usage to ‘, gpu_profile_fn)
lineno = None
func_name = None
filename = None
module_name = None
# fram = inspect.currentframe()
# func_name = fram.f_code.co_name
# filename = fram.f_globals["__file__"]
# ss = os.path.dirname(os.path.abspath(filename))
# module_name = fram.f_globals["__name__"]
def gpu_profile(frame, event):
# it is _about to_ execute (!)
global last_tensor_sizes
global lineno, func_name, filename, module_name
if event == ‘line‘:
try:
# about _previous_ line (!)
if lineno is not None:
pynvml.nvmlInit()
# handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ[‘GPU_DEBUG‘]))
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
line = linecache.getline(filename, lineno)
where_str = module_name+‘ ‘+func_name+‘:‘+‘ line ‘+str(lineno)
with open(gpu_profile_fn, ‘a+‘) as f:
f.write(f"At {where_str:<50}"
f"Total Used Memory:{meminfo.used/1024**2:<7.1f}Mb\n")
if print_tensor_sizes is True:
for tensor in get_tensors():
if not hasattr(tensor, ‘dbg_alloc_where‘):
tensor.dbg_alloc_where = where_str
new_tensor_sizes = {(type(x), tuple(x.size()), np.prod(np.array(x.size()))*4/1024**2,
x.dbg_alloc_where) for x in get_tensors()}
for t, s, m, loc in new_tensor_sizes - last_tensor_sizes:
f.write(f‘+ {loc:<50} {str(s):<20} {str(m)[:4]} M {str(t):<10}\n‘)
for t, s, m, loc in last_tensor_sizes - new_tensor_sizes:
f.write(f‘- {loc:<50} {str(s):<20} {str(m)[:4]} M {str(t):<10}\n‘)
last_tensor_sizes = new_tensor_sizes
pynvml.nvmlShutdown()
# save details about line _to be_ executed
lineno = None
func_name = frame.f_code.co_name
filename = frame.f_globals["__file__"]
if (filename.endswith(".pyc") or
filename.endswith(".pyo")):
filename = filename[:-1]
module_name = frame.f_globals["__name__"]
lineno = frame.f_lineno
return gpu_profile
except Exception as e:
print(‘A exception occured: {}‘.format(e))
return gpu_profile
def get_tensors():
for obj in gc.get_objects():
try:
if torch.is_tensor(obj):
tensor = obj
else:
continue
if tensor.is_cuda:
yield tensor
except Exception as e:
print(‘A exception occured: {}‘.format(e))
需要注意的是,linecache中的getlines只能读取缓冲过的文件,如果这个文件没有运行过则返回无效值。Python 的垃圾收集机制会在变量没有应引用的时候立马进行回收,但是为什么模型中计算的中间变量在执行结束后还会存在呢。既然都没有引用了为什么还会占用空间?
一种可能的情况是这些引用不在Python代码中,而是在神经网络层的运行中为了backward被保存为gradient,这些引用都在计算图中,我们在程序中是无法看到的:
后记
实际中我们会有些只使用一次的模型,为了节省显存,我们需要一边计算一遍清除中间变量,使用del进行操作。限于篇幅这里不进行讲解,下一篇会进行说明。
原文地址:如何在Pytorch中精细化利用显存
<br>
原创文章,转载请注明 :<a href="https://ptorch.com/news/181.html" target="_blank">如何在Pytorch中精细化利用显存以及提高Pytorch显存利用率 - pytorch中文网</a><br>
原文出处: https://ptorch.com/news/181.html<br>
问题交流群 :168117787
</div>