码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch之 save_reload_model

时间:2019-10-26 15:19:35      阅读:83      评论:0      收藏:0      [点我收藏+]

标签:class   ble   sgd   mat   code   step   lin   stat   auto   

 1 import torch
 2 import matplotlib.pyplot as plt
 3 
 4 # torch.manual_seed(1)    # reproducible
 5 
 6 # fake data
 7 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
 8 y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 9 
10 # The code below is deprecated in Pytorch 0.4. Now, autograd directly supports tensors
11 # x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
12 
13 
def save(net_path='net.pkl', params_path='net_params.pkl'):
    """Train a small regression net on the module-level ``x``/``y`` data and save it.

    The trained network is written to disk in two ways: the whole pickled
    module goes to ``net_path``, and only its ``state_dict`` goes to
    ``params_path``. The fitted curve is drawn into the first of three
    subplots of figure 1.

    Args:
        net_path: destination file for the entire module (``torch.save(net1, ...)``).
        params_path: destination file for ``net1.state_dict()``.
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        loss = loss_func(net1(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate AFTER the last optimizer step, so the plotted curve uses exactly
    # the weights that are saved below (and later reloaded as net2 / net3).
    with torch.no_grad():
        prediction = net1(x)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

    # 2 ways to save the net
    torch.save(net1, net_path)  # save entire net (pickles the module object)
    torch.save(net1.state_dict(), params_path)  # save only the parameters
41 
42 
def restore_net(net_path='net.pkl'):
    """Reload the entire pickled network written by ``save()`` as net2 and plot it.

    Draws into the second of three subplots of figure 1.

    Args:
        net_path: file previously written with ``torch.save(net1, ...)``.
    """
    # Unpickling a whole nn.Module requires weights_only=False. Since PyTorch
    # 2.6, torch.load defaults to weights_only=True, which rejects arbitrary
    # pickled objects. This path runs pickle, so only load files you trust.
    # (The weights_only argument exists from PyTorch 1.13 on.)
    net2 = torch.load(net_path, weights_only=False)
    with torch.no_grad():
        prediction = net2(x)

    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
53 
54 
def restore_params(params_path='net_params.pkl'):
    """Rebuild the architecture, load the saved parameters into it as net3, and plot it.

    The layer stack below must match the one trained in ``save()``; otherwise
    ``load_state_dict`` (strict by default) raises on mismatched keys or
    shapes. Draws into the third of three subplots of figure 1 and then
    shows the figure.

    Args:
        params_path: file previously written with
            ``torch.save(net1.state_dict(), ...)``.
    """
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )

    # Copy net1's parameters into net3. A state_dict holds only tensors, so the
    # restricted weights_only=True loader is sufficient and avoids running
    # arbitrary pickle code (the argument exists from PyTorch 1.13 on).
    net3.load_state_dict(torch.load(params_path, weights_only=True))
    with torch.no_grad():
        prediction = net3(x)

    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
    plt.show()
73 
# Demo sequence:
#   1. save()           - train net1 and save it (whole module + parameters only)
#   2. restore_net()    - reload the entire pickled net (may be slow)
#   3. restore_params() - rebuild the net and load only the saved parameters
for _step in (save, restore_net, restore_params):
    _step()

 

pytorch之 save_reload_model

标签:class   ble   sgd   mat   code   step   lin   stat   auto   

原文地址:https://www.cnblogs.com/dhName/p/11742959.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!