码迷,mamicode.com
首页 > 其他好文 > 详细

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

时间:2019-10-20 15:51:01      阅读:270      评论:0      收藏:0      [点我收藏+]

标签:参考   state   网络模型   strong   ORC   tps   ges   方法   ima   

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载预测模型问题

上一篇实践学习中,遇到了在多/单个GPU、GPU与CPU的不同环境下训练保存、加载使用使用模型的问题,如果保存、加载的上述三类环境不同,加载时会出错。就去研究了一下,做了实验,得出以下结论:

多/单GPU训练保存模型参数、CPU加载使用模型

#保存
PATH = 'cifar_net.pth'
torch.save(net.module.state_dict(), PATH)

#加载
net = Net()
net.load_state_dict(torch.load(PATH))

多GPU训练模型、单GPU加载使用模型

#保存
PATH = 'cifar_net.pth'
torch.save(net.state_dict(), PATH)

#加载
net = Net()
net = nn.DataParallel(net)  #保存多GPU的,在加载时需要把网络也转成DataParallel的
net.to(device)  #放到GPU上
net.load_state_dict(torch.load(PATH))

# 然后测试数据也需要放到GPU上
images, labels = images.to(device), labels.to(device)

多GPU训练保存模型参数、多GPU加载使用模型

#保存
PATH = 'cifar_net.pth'
torch.save(net.state_dict(), PATH)

#加载
net = Net()
net = nn.DataParallel(net)  #保存多GPU的,在加载时需要把网络也转成DataParallel的
net.to(device)  #放到GPU上
net.load_state_dict(torch.load(PATH))

# 然后测试数据也需要放到GPU上
images, labels = images.to(device), labels.to(device)

可以看到,单GPU和多GPU加载数据的方法其实是一样的,经运行验证,只要按上述代码写,有多个GPU就调用多个,只有一个就调用一个。

另外,保存、加载网络模型还有三种不同的做法

1.保存整个网络模型
2.只保存模型参数(我们用的就是这种)
3.自定义保存

详细方法,请参考:https://blog.csdn.net/Code_Mart/article/details/88254444

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

标签:参考   state   网络模型   strong   ORC   tps   ges   方法   ima   

原文地址:https://www.cnblogs.com/importGPX/p/11707642.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!