码迷,mamicode.com
首页 > 其他好文 > 详细

模型的加载和保存

时间:2020-06-25 21:25:26      阅读:58      评论:0      收藏:0      [点我收藏+]

标签:you   return   pat   文件名   ade   module   模型   变化   makefile   

pytorch三种模型的加载保存操作

方法1 : PATH表示保存模型的路径和文件名

torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...

# save model
FILE = "model.pth"
torch.save(model, FILE)

# load model
model = torch.load(FILE)

# 防止模型参数发生变化
model.eval()
for param in model.parameters():
    print(param)

方法二:

保存模型时使用模型的state_dict()方法,加载模型前先实例化一个模型,然后调用load_state_dict()方法

torch.save(model.state_dict(), PATH)
# model must be created again with parameters
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...


for param in model.parameters():
    print(param)

# save model
FILE = "model.pth"
torch.save(model.state_dict(), FILE)

loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))

# 防止模型参数发生变化
loaded_model.eval()
for param in loaded_model.parameters():
    print(param)

方法三:

定义一个字典,保存多个参数到模型

class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...

# print(model.state_dict())


learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# print(optimizer.state_dict())


checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict()
}
# 保存三种数据到模型
torch.save(checkpoint, "checkpoint.pth")

# 加载模型
loaded_checkpoint = torch.load("checkpoint.pth")
# 载入epcho数据
epoch = loaded_checkpoint[‘epoch‘]
print(epoch)

# 定义模型和优化器
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)


# 将保存的模型数据载入到模型和优化器中
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optim_state"])

 推荐:什么是顺时针,你看到的是顺时针还是逆时针

模型的加载和保存

标签:you   return   pat   文件名   ade   module   模型   变化   makefile   

原文地址:https://www.cnblogs.com/1994july/p/13192735.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!