码迷,mamicode.com
首页 > 其他好文 > 详细

深度学习之构造模型,访问模型参数——2020.3.11

时间:2020-03-11 12:27:29      阅读:78      评论:0      收藏:0      [点我收藏+]

标签:info   back   red   包含   简便   super   val   sum   fun   

今天主要学习了利用torch中的nn模块定义Module类,下面的代码包含对于模型类的构建以及参数访问,简便的可以使用‘net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())’构建模型,默认进行初始化。

# 3.1 构造模型
import torch 
from torch import nn #module类是nn模块里提供的一个模型构造类

# 定义MLP类
class MLP(nn.Module):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs) #重载MLP类
        self.hidden = nn.Linear(784, 256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)
    
    # 定义前向计算,反向传播函数可通过生成反向传播所需的backward函数
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
# 初始化net并传入输入数据x,做前向计算
X = torch.rand(2, 784)
net = MLP()
net(X)
![](https://img2020.cnblogs.com/blog/1190122/202003/1190122-20200311120753688-1814624801.png)
# 4.12 module 的子类
class MySquential(nn.Module):
    from collections import OrderedDict
    def __init__(self, *args):
        super(MySquential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key,module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
    
    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input
net = MySquential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10),)
print(net)
net(X)

输出结果

技术图片

# ModuleLise 类
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10))
print(net[0]) #使用Listd的索引访问
print(net)

输出结果

技术图片

# ModuleDict类
net = nn.ModuleDict({'linear' : nn.Linear(784, 256), 'act' : nn.ReLU(),})
net['output'] = nn.Linear(256, 10)
print(net['linear']) # 访问
print(net.output)
print(net)
# 构造复杂模型
class FancyMLP(nn.Module):
    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)
        
        self.rand_weight = torch.rand((20, 20),requires_grad=False)
        self.linear = nn.Linear(20, 20)
        
    def forward(self, x):
        x = self.linear(x)
        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
        
        x = self.linear(x)
        while x.norm().item() > 1:
            x /= 2
        if x.norm().item() < 0.0:
            x *= 10
        return x.sum()
        
X = torch.rand(2, 20)
net = FancyMLP()
print(net)
net(X)

输出结果

技术图片

# 嵌套调用FancyMLP和Sequential类
class NestMLP(nn.Module):
    def __init__(self, **kwargs):
        super(NestMLP, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(40,30), nn.ReLU())
    
    def forward(self, x):
        return self.net(x)

net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())

X = torch.rand(2, 40)
print(net)
net(X)

输出结果

技术图片

# 4.2 模型参数的访问、初始化和共享
import torch
from torch import nn
from torch.nn import init

net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) 

print(net)
X = torch.rand(2, 4)
Y = net(X).sum()

输出结果

技术图片

# 访问模型参数
print(type(net.named_parameters()))
for name, param in net.named_parameters():
    print(name, param.size())

输出结果

技术图片

深度学习之构造模型,访问模型参数——2020.3.11

标签:info   back   red   包含   简便   super   val   sum   fun   

原文地址:https://www.cnblogs.com/somedayLi/p/12461476.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!