码迷,mamicode.com
首页 > 其他好文 > 详细

pytorch模型参数

时间:2019-08-05 20:25:45      阅读:159      评论:0      收藏:0      [点我收藏+]

标签:参数   style   ini   __init__   span   dict   parameter   class   color   

1、torch.nn.state_dict():

返回一个字典,保存着module的所有状态(state)。

parameters和persistent_buffers都会包含在字典中,字典的key就是parameter和buffer的names。

例子:

import torch
from torch.autograd import Variable
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv2 = nn.Linear(1, 2)
        self.vari = Variable(torch.rand([1]))
        self.par = nn.Parameter(torch.rand([1]))
        self.register_buffer("buffer", torch.randn([2,3]))

model = Model()
print(model.state_dict().keys())
odict_keys([par, buffer, conv2.weight, conv2.bias])

 

字典迭代形式{<class ‘str‘>:<class ‘torch.Tensor‘>, ... }

pytorch模型参数

标签:参数   style   ini   __init__   span   dict   parameter   class   color   

原文地址:https://www.cnblogs.com/lucifer1997/p/11305150.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!