标签:info back red 包含 简便 super val sum fun
今天主要学习了利用torch中的nn模块定义Module
类,下面的代码包含对于模型类的构建以及参数访问,简便的可以使用‘net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())’构建模型,默认进行初始化。
# 3.1 构造模型
import torch
from torch import nn #module类是nn模块里提供的一个模型构造类
# 定义MLP类
class MLP(nn.Module):
def __init__(self, **kwargs):
super(MLP, self).__init__(**kwargs) #重载MLP类
self.hidden = nn.Linear(784, 256)
self.act = nn.ReLU()
self.output = nn.Linear(256, 10)
# 定义前向计算,反向传播函数可通过生成反向传播所需的backward函数
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
# 初始化net并传入输入数据x,做前向计算
X = torch.rand(2, 784)
net = MLP()
net(X)
# 4.12 module 的子类
class MySquential(nn.Module):
from collections import OrderedDict
def __init__(self, *args):
super(MySquential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key,module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
for module in self._modules.values():
input = module(input)
return input
net = MySquential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10),)
print(net)
net(X)
输出结果
# ModuleLise 类
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10))
print(net[0]) #使用Listd的索引访问
print(net)
输出结果
# ModuleDict类
net = nn.ModuleDict({'linear' : nn.Linear(784, 256), 'act' : nn.ReLU(),})
net['output'] = nn.Linear(256, 10)
print(net['linear']) # 访问
print(net.output)
print(net)
# 构造复杂模型
class FancyMLP(nn.Module):
def __init__(self, **kwargs):
super(FancyMLP, self).__init__(**kwargs)
self.rand_weight = torch.rand((20, 20),requires_grad=False)
self.linear = nn.Linear(20, 20)
def forward(self, x):
x = self.linear(x)
x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
x = self.linear(x)
while x.norm().item() > 1:
x /= 2
if x.norm().item() < 0.0:
x *= 10
return x.sum()
X = torch.rand(2, 20)
net = FancyMLP()
print(net)
net(X)
输出结果
# 嵌套调用FancyMLP和Sequential类
class NestMLP(nn.Module):
def __init__(self, **kwargs):
super(NestMLP, self).__init__(**kwargs)
self.net = nn.Sequential(nn.Linear(40,30), nn.ReLU())
def forward(self, x):
return self.net(x)
net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
X = torch.rand(2, 40)
print(net)
net(X)
输出结果
# 4.2 模型参数的访问、初始化和共享
import torch
from torch import nn
from torch.nn import init
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
print(net)
X = torch.rand(2, 4)
Y = net(X).sum()
输出结果
# 访问模型参数
print(type(net.named_parameters()))
for name, param in net.named_parameters():
print(name, param.size())
输出结果
标签:info back red 包含 简便 super val sum fun
原文地址:https://www.cnblogs.com/somedayLi/p/12461476.html