码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch实现Softmax代码

时间:2020-02-14 22:19:49      阅读:256      评论:0      收藏:0      [点我收藏+]

标签:war   imp   初始化   nes   ict   version   layer   seq   ant   

 1 # 加载各种包或者模块
 2 import torch
 3 from torch import nn
 4 from torch.nn import init
 5 import numpy as np
 6 import sys
 7 sys.path.append("/home/kesci/input")
 8 import d2lzh1981 as d2l
 9 
10 print(torch.__version__)
1 # 初始化参数和获取数据
2 
3 batch_size = 256
4 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root=/home/kesci/input/FashionMNIST2065)
 1 num_inputs = 784
 2 num_outputs = 10
 3 
 4 class LinearNet(nn.Module):
 5     def __init__(self, num_inputs, num_outputs):
 6         super(LinearNet, self).__init__()
 7         self.linear = nn.Linear(num_inputs, num_outputs)
 8     def forward(self, x): # x 的形状: (batch, 1, 28, 28)
 9         y = self.linear(x.view(x.shape[0], -1))
10         return y
11     
12 # net = LinearNet(num_inputs, num_outputs)
13 
14 class FlattenLayer(nn.Module):
15     def __init__(self):
16         super(FlattenLayer, self).__init__()
17     def forward(self, x): # x 的形状: (batch, *, *, ...)
18         return x.view(x.shape[0], -1)
19 
20 from collections import OrderedDict
21 net = nn.Sequential(
22         # FlattenLayer(),
23         # LinearNet(num_inputs, num_outputs) 
24         OrderedDict([
25            (flatten, FlattenLayer()),
26            (linear, nn.Linear(num_inputs, num_outputs))]) # 或者写成我们自己定义的 LinearNet(num_inputs, num_outputs) 也可以
27         )
 1 # 初始化模型参数
 2 init.normal_(net.linear.weight, mean=0, std=0.01)
 3 init.constant_(net.linear.bias, val=0)
 4 
 5 # 定义损失函数
 6 loss = nn.CrossEntropyLoss() # 下面是他的函数原型
 7 # class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction=‘mean‘)
 8 
 9 # 定义优化函数
10 optimizer = torch.optim.SGD(net.parameters(), lr=0.1) # 下面是函数原型
11 # class torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
12 
13 # 训练
14 num_epochs = 5
15 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

 

PyTorch实现Softmax代码

标签:war   imp   初始化   nes   ict   version   layer   seq   ant   

原文地址:https://www.cnblogs.com/hahasd/p/12309695.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!