码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch模型加载与保存的最佳实践

时间:2020-05-18 21:07:26      阅读:323      评论:0      收藏:0      [点我收藏+]

标签:drop   固定   pat   hang   bsp   模型   loss   check   object   

一般来说PyTorch有两种保存和读取模型参数的方法。但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题。

传统方案:

第一种方案是保存整个模型:

torch.save(model_object, model.pth)

第二种方法是保存模型网络参数:

torch.save(model_object.state_dict(), params.pth)

加载的时候分别这样加载:

model = torch.load(model.pth)

以及:

model_object.load_state_dict(torch.load(params.pth))

改进的方案

注意到这个方案是因为模型在加载之后,loss会飙升之后再慢慢降回来。查阅有关分析之后,判定是优化器optimizer的问题。

如果模型的保存是为了恢复训练状态,那么可以考虑同时保存优化器optimizer的参数:

state = {
    epoch: epoch,
    net: model.state_dict(),
    optimizer: optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

然后这样加载:

checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint[net])
optimizer.load_state_dict(checkpoint[optimizer])
start_epoch =  checkpoint[epoch] + 1

如果模型的保存是为了方便以后进行validation和test,可以在加载完之后制定model.eval()固定dropout和BN层

 

https://ldzhangyx.github.io/2018/11/19/pytorch-1119/

PyTorch模型加载与保存的最佳实践

标签:drop   固定   pat   hang   bsp   模型   loss   check   object   

原文地址:https://www.cnblogs.com/jiangkejie/p/12912665.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!