
PyTorch -- RNN language model -- "Recurrent neural network based language model"

Posted: 2019-11-09 17:51:24


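The paper builds a statistical language model on a recurrent neural network (RNN), which predicts the next word from the words seen so far. The code below applies the same idea at toy scale. Given the first two words of each sentence, an RNN predicts the third word, and this is framed as a classification over the vocabulary.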

Paper link: 88888888
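For reference, the paper describes a simple (Elman) recurrent network built from three steps:

- The input x(t) concatenates the one-hot vector of the current word w(t) with the previous hidden state s(t-1).
- The hidden state is s(t) = f(U x(t)), where f is the sigmoid.
- The output is y(t) = g(V s(t)), where g is the softmax. This gives a probability distribution over the next word.

The NumPy sketch below runs that recurrence over a short word sequence. It is a sketch under stated assumptions, not the author's code. The function name `elman_lm_forward`, the matrix names `U` and `V`, the sizes, and the random untrained weights are all illustrative. PyTorch's `nn.RNN`, used in the code further down, differs slightly: by default it applies `tanh` and adds bias terms.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def elman_lm_forward(word_ids, vocab_size, hidden_size, seed=0):
    """Hypothetical sketch: run an Elman-style RNN LM with random, untrained
    weights over word_ids and return one next-word distribution per step."""
    rng = np.random.default_rng(seed)
    # U maps [w(t); s(t-1)] to the hidden layer; V maps the hidden layer to the vocabulary
    U = rng.normal(scale=0.1, size=(hidden_size, vocab_size + hidden_size))
    V = rng.normal(scale=0.1, size=(vocab_size, hidden_size))

    s = np.zeros(hidden_size)  # s(0): initial hidden state
    outputs = []
    for idx in word_ids:
        w = np.eye(vocab_size)[idx]     # one-hot encoding of the current word
        x = np.concatenate([w, s])      # x(t) = [w(t); s(t-1)]
        s = sigmoid(U @ x)              # s(t) = f(U x(t))
        outputs.append(softmax(V @ s))  # y(t) = g(V s(t)) = P(next word | history)
    return np.stack(outputs)

# Toy usage: a 7-word vocabulary, 5 hidden units, a 2-word history
probs = elman_lm_forward([0, 3], vocab_size=7, hidden_size=5)
print(probs.shape)         # (2, 7): one distribution per time step
print(probs.sum(axis=1))   # each row sums to 1 (up to floating-point rounding)
```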

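Model architecture:

[Figure: model architecture diagram (image not available)]

For the underlying theory, see the paper. Code with comments: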

```python
# -*- coding: utf-8 -*-
# @time : 2019/11/9  15:12

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

sentences = ["i like dog", "i love coffee", "i hate milk"]

# Build the vocabulary: word -> index and index -> word
word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}
n_class = len(word_dict)  # vocabulary size

# TextRNN Parameter
batch_size = len(sentences)
n_step = 2  # number of cells (= number of time steps)
n_hidden = 5  # number of hidden units in one cell

def make_batch(sentences):
    input_batch = []
    target_batch = []

    for sen in sentences:
        word = sen.split()
        input_ids = [word_dict[n] for n in word[:-1]]  # the first n_step words are the input
        target = word_dict[word[-1]]  # the last word is the target to predict

        input_batch.append(np.eye(n_class)[input_ids])  # one-hot rows: [n_step, n_class]
        target_batch.append(target)

    return input_batch, target_batch

# to torch.Tensor (Variable is deprecated; plain tensors track gradients directly)
input_batch, target_batch = make_batch(sentences)
input_batch = torch.tensor(np.array(input_batch), dtype=torch.float32)  # [batch_size, n_step, n_class]
target_batch = torch.tensor(target_batch, dtype=torch.long)  # [batch_size]

class TextRNN(nn.Module):
    def __init__(self):
        super().__init__()

        self.rnn = nn.RNN(input_size=n_class, hidden_size=n_hidden, batch_first=True)
        # Output projection: last hidden state -> scores over the vocabulary
        self.W = nn.Parameter(torch.randn(n_hidden, n_class))
        self.b = nn.Parameter(torch.randn(n_class))

    def forward(self, hidden, X):
        # hidden : [num_layers * num_directions, batch_size, n_hidden]
        if self.rnn.batch_first:
            # X : [batch_size, n_step, n_class]
            outputs, hidden = self.rnn(X, hidden)
            # outputs : [batch_size, n_step, num_directions(=1) * n_hidden]
            output = outputs[:, -1, :]  # last time step: [batch_size, n_hidden]
        else:
            X = X.transpose(0, 1)  # X : [n_step, batch_size, n_class]
            outputs, hidden = self.rnn(X, hidden)
            # outputs : [n_step, batch_size, num_directions(=1) * n_hidden]
            # hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
            output = outputs[-1, :, :]  # last time step: [batch_size, n_hidden]
        logits = torch.mm(output, self.W) + self.b  # logits : [batch_size, n_class]
        return logits

model = TextRNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()

    # hidden : [num_layers * num_directions, batch_size, n_hidden]
    hidden = torch.zeros(1, batch_size, n_hidden)
    # input_batch : [batch_size, n_step, n_class]
    output = model(hidden, input_batch)

    # output : [batch_size, n_class], target_batch : [batch_size] (LongTensor of class indices, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

# Predict: take the highest-scoring word for each sentence
with torch.no_grad():
    hidden_initial = torch.zeros(1, batch_size, n_hidden)
    predict = model(hidden_initial, input_batch).max(1, keepdim=True)[1]
print([sen.split()[:2] for sen in sentences], '->', [number_dict[n.item()] for n in predict.squeeze()])
```
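`make_batch` encodes each word as a one-hot row with `np.eye(n_class)[input]`. Indexing the identity matrix with a list of word ids picks out one row per word. Multiplying a one-hot row by a weight matrix simply selects one row of that matrix. That is why larger language models usually replace the one-hot input with an embedding lookup, such as `nn.Embedding` in PyTorch.

The small check below illustrates both points. It assumes a hypothetical vocabulary size, word ids, and random weights `W_in`; none of these come from the original script.

```python
import numpy as np

n_class = 7          # hypothetical vocabulary size
word_ids = [2, 5]    # a hypothetical 2-word input, as produced by word_dict lookups

one_hot = np.eye(n_class)[word_ids]  # shape (2, 7), exactly one 1.0 per row
print(one_hot)

# Multiplying one-hot rows by a weight matrix selects rows of that matrix,
# i.e. it is equivalent to an embedding (row) lookup.
rng = np.random.default_rng(0)
W_in = rng.normal(size=(n_class, 4))  # hypothetical input weights, feature size 4
assert np.allclose(one_hot @ W_in, W_in[word_ids])
```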
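`forward` has two branches:

- With `batch_first=True`, `nn.RNN` takes input shaped `[batch_size, n_step, n_class]`, and the last time step is `outputs[:, -1, :]`.
- Otherwise the input is first transposed to `[n_step, batch_size, n_class]`, and the last time step is `outputs[-1, :, :]`.

One way to check that the branches agree is to load the same weights into two `nn.RNN` modules that differ only in `batch_first`. The sketch below does that with arbitrary sizes. It also checks that, for a single-layer unidirectional RNN, the returned hidden state matches the last step of `outputs`. Note that `batch_first` does not change the layout of the hidden state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, n_step, n_class, n_hidden = 3, 2, 7, 5  # arbitrary sizes for the check

rnn_bf = nn.RNN(input_size=n_class, hidden_size=n_hidden, batch_first=True)
rnn_tf = nn.RNN(input_size=n_class, hidden_size=n_hidden, batch_first=False)
rnn_tf.load_state_dict(rnn_bf.state_dict())  # same weights; batch_first is not a parameter

X = torch.randn(batch_size, n_step, n_class)  # [batch_size, n_step, n_class]
h0 = torch.zeros(1, batch_size, n_hidden)     # hidden state is never batch-first

with torch.no_grad():
    out_bf, h_bf = rnn_bf(X, h0)                  # out_bf: [batch_size, n_step, n_hidden]
    out_tf, h_tf = rnn_tf(X.transpose(0, 1), h0)  # out_tf: [n_step, batch_size, n_hidden]

last_bf = out_bf[:, -1, :]  # what the batch_first branch of forward() selects
last_tf = out_tf[-1, :, :]  # what the time-major branch of forward() selects
print(torch.allclose(last_bf, last_tf))   # True
print(torch.allclose(last_bf, h_bf[-1]))  # True: h_n holds the last time step's output
```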
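The training loop passes the raw scores (logits) from `forward` straight into `nn.CrossEntropyLoss`, and `target_batch` holds class indices rather than one-hot vectors. This works because `CrossEntropyLoss` applies log-softmax internally and then takes the negative log-likelihood of the target index. The model itself therefore needs no softmax layer.

The sketch below checks this equivalence on random logits, with arbitrary values and sizes. It also confirms that taking the argmax over the logits, as the prediction step does with `max(1)`, picks the same word as taking it over the softmax probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 7)        # [batch_size, n_class]: raw scores, as returned by the model
target = torch.tensor([4, 0, 6])  # [batch_size]: class indices (dtype torch.long)

ce = nn.CrossEntropyLoss()(logits, target)
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)  # log-softmax + negative log-likelihood
assert torch.allclose(ce, nll)

# Manual version: mean over the batch of -log p(target word)
probs = F.softmax(logits, dim=1)
manual = -torch.log(probs[torch.arange(3), target]).mean()
assert torch.allclose(ce, manual)

# Softmax preserves ordering, so argmax over logits equals argmax over probabilities
assert torch.equal(logits.max(1)[1], probs.max(1)[1])
print(ce.item())
```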

 

Original post: https://www.cnblogs.com/dhName/p/11826541.html
