码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch实现循环神经网络

时间:2020-02-14 22:12:59      阅读:96      评论:0      收藏:0      [点我收藏+]

标签:available   ORC   and   argmax   ack   tensor   data   init   __init__   

import math
import sys
import time

import torch
import torch.nn as nn

# Make the d2l_jay9460 helper module importable; presumably it lives in this
# directory in the course environment -- adjust the path if it does not.
sys.path.append("/home/kesci/input")
import d2l_jay9460 as d2l  # noqa: E402  (must follow the sys.path change above)

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
# torch.device takes the device *name* as a string; unquoted cuda/cpu raised NameError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Width of the RNN hidden state. It was referenced below without ever being
# defined (NameError); 256 is a conventional size for this demo.
num_hiddens = 256
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)

# Smoke-test the layer on random input laid out as
# (num_steps, batch_size, vocab_size) -- nn.RNN defaults to batch_first=False.
num_steps, batch_size = 35, 2
X = torch.rand(num_steps, batch_size, vocab_size)
state = None  # no initial hidden state given, so nn.RNN starts from zeros
Y, state_new = rnn_layer(X, state)
# Y: (num_steps, batch_size, num_hiddens); state_new: (num_layers, batch_size, num_hiddens)
print(Y.shape, state_new.shape)
 1 class RNNModel(nn.Module):
 2     def __init__(self, rnn_layer, vocab_size):
 3         super(RNNModel, self).__init__()
 4         self.rnn = rnn_layer
 5         self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
 6         self.vocab_size = vocab_size
 7         self.dense = nn.Linear(self.hidden_size, vocab_size)
 8 
 9     def forward(self, inputs, state):
10         # inputs.shape: (batch_size, num_steps)
11         X = to_onehot(inputs, vocab_size)
12         X = torch.stack(X)  # X.shape: (num_steps, batch_size, vocab_size)
13         hiddens, state = self.rnn(X, state)
14         hiddens = hiddens.view(-1, hiddens.shape[-1])  # hiddens.shape: (num_steps * batch_size, hidden_size)
15         output = self.dense(hiddens)
16         return output, state
 1 def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,
 2                       char_to_idx):
 3     state = None
 4     output = [char_to_idx[prefix[0]]]  # output记录prefix加上预测的num_chars个字符
 5     for t in range(num_chars + len(prefix) - 1):
 6         X = torch.tensor([output[-1]], device=device).view(1, 1)
 7         (Y, state) = model(X, state)  # 前向计算不需要传入模型参数
 8         if t < len(prefix) - 1:
 9             output.append(char_to_idx[prefix[t + 1]])
10         else:
11             output.append(Y.argmax(dim=1).item())
12     return ‘‘.join([idx_to_char[i] for i in output])

PyTorch实现循环神经网络

标签:available   ORC   and   argmax   ack   tensor   data   init   __init__   

原文地址:https://www.cnblogs.com/hahasd/p/12309767.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!