
PyTorch seq2seq machine translation model (two versions: without and with attention)



The corpus is small and training time is short, so the model does not perform well; the point here is to walk through the process.

Corpus link:

Data format: the English sentence first, then a tab, then the traditional Chinese translation.

The following code was run on Google Colab.

Imports:

import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
nltk.download('punkt')

1. Data preprocessing

1.1 Read in the English and Chinese data

  • English is tokenized with NLTK's word tokenizer and lower-cased
  • Chinese is split into individual characters
def load_data(in_file):
    cn = []
    en = []
    num_examples = 0
    with open(in_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split("\t")

            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn

train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

Check the returned data:

print(dev_en[:2])
print(dev_cn[:2])

[['BOS', 'she', 'put', 'the', 'magazine', 'on', 'the', 'table', '.', 'EOS'], ['BOS', 'hey', ',', 'what', 'are', 'you', 'doing', 'here', '?', 'EOS']]

[['BOS', '她', '把', '雜', '誌', '放', '在', '桌', '上', '。', 'EOS'], ['BOS', '嘿', ',', '你', '在', '這', '做', '什', '麼', '?', 'EOS']]

1.2 Build the vocabulary

UNK_IDX = 0
PAD_IDX = 1
def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    word_dict = {w[0]: index+2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words      # total_words is the vocabulary size, at most 50002

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v: k for k, v in en_dict.items()}    # English: index -> word
inv_cn_dict = {v: k for k, v in cn_dict.items()}    # Chinese: index -> character
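
A quick sanity check (illustrative; the exact counts depend on the corpus) shows the reserved indices and the vocabulary sizes:

print(en_total_words, cn_total_words)   # at most 50002 each
print(en_dict["UNK"], en_dict["PAD"])   # 0, 1 (reserved indices)
print(inv_en_dict[2])                   # index 2 is the most frequent English token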

1.3 Convert all words to indices

sort_by_len=True sorts the sentences by length, so that sentences in the same batch end up roughly the same length.

def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):

    length = len(en_sentences)
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

    # sort sentences by length
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # sort the English and Chinese sentences in the same order
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)

Check the returned data:

print(train_cn[2])
print([inv_cn_dict[i] for i in train_cn[2]])
print([inv_en_dict[i] for i in train_en[2]])

[2, 982, 2028, 8, 4, 3]

['BOS', '祝', '贺', '你', '。', 'EOS']

['BOS', 'congratulations', '!', 'EOS']

1.4 Split the sentences into batches

def get_minibatches(n, minibatch_size, shuffle=True):   # n is the number of sentences
    idx_list = np.arange(0, n, minibatch_size)   # start index of each minibatch: 0, minibatch_size, 2*minibatch_size, ...
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches

See what the function above does:

get_minibatches(100, 15)

[array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]),
 array([75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
 array([45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
 array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]),
 array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])]
def prepare_data(seqs):   # seqs is one minibatch: a nested list of batch_size sentence index lists (here batch_size = 64)
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)   # length of the longest sentence in this batch

    x = np.zeros((n_samples, max_len)).astype("int32")
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths

def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_ex = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex   # each element: English indices, English lengths, Chinese indices, Chinese lengths for one batch

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
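
A quick look at the first training batch (illustrative; the padded lengths depend on the data) confirms the shapes:

mb_x, mb_x_len, mb_y, mb_y_len = train_data[0]
print(mb_x.shape, mb_y.shape)        # (64, max English length) and (64, max Chinese length) for this batch
print(mb_x_len[:5], mb_y_len[:5])    # true (unpadded) lengths of the first five sentence pairs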

2. Encoder-Decoder model (without attention)

2.1 Define the loss function

# masked cross entropy loss
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):   # positions where mask == 0 are ignored
        # input: [batch_size, seq_len, vocab_size] (log-probabilities)
        input = input.contiguous().view(-1, input.size(2))
        # target, mask: [batch_size, seq_len] -> [batch_size * seq_len, 1]
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output
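
A tiny worked example with made-up numbers shows what the mask does: only the two unmasked positions contribute to the average.

log_probs = torch.log(torch.tensor([[[0.7, 0.2, 0.1],
                                     [0.1, 0.8, 0.1],
                                     [0.3, 0.3, 0.4]]]))   # [batch=1, seq_len=3, vocab=3]
target = torch.tensor([[0, 1, 2]])                          # correct word index at each position
mask = torch.tensor([[1., 1., 0.]])                         # the last position is padding
loss = LanguageModelCriterion()(log_probs, target, mask)
print(loss.item())   # -(log 0.7 + log 0.8) / 2 ≈ 0.290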

2.2 Encoder

The encoder passes the input tokens through an embedding layer and a GRU, producing hidden states that later serve as the context vectors;

For an explanation of nn.utils.rnn.pack_padded_sequence and nn.utils.rnn.pad_packed_sequence, see: http://www.mamicode.com/info-detail-2493083.html
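
As a minimal standalone illustration (toy numbers, not from the post): packing lets the GRU stop at each sequence's true length, and pad_packed_sequence restores a padded tensor afterwards.

import torch
import torch.nn as nn

seqs = torch.tensor([[1., 2., 3.], [4., 0., 0.]]).unsqueeze(-1)   # [batch=2, max_len=3, feature=1]; the second row is padded
lengths = torch.tensor([3, 1])                                    # true lengths, sorted in descending order
rnn = nn.GRU(1, 4, batch_first=True)
packed = nn.utils.rnn.pack_padded_sequence(seqs, lengths, batch_first=True)
packed_out, hid = rnn(packed)        # hid is taken at each sequence's true last step, not at the padded end
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, out_lengths)        # torch.Size([2, 3, 4]) tensor([3, 1]); padded positions in out are zeros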

class PlainEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):   # embedding_size is assumed equal to hidden_size
        super(PlainEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):   # lengths are needed so the last real hidden state can be used as the context vector
        sorted_len, sorted_idx = lengths.sort(0, descending=True)   # sort the sequences in the batch by length, descending
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        # Sentences are padded to the same length (the real sentences are shorter); pack_padded_sequence
        # lets the RNN stop at each sequence's true length, so the final state is taken at the real last token
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)    # back to padded form

        _, original_idx = sorted_idx.sort(0, descending=False)                     # restore the original batch order
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        return out, hid[[-1]]   # hid[[-1]] keeps the top layer's final hidden state, shape [1, batch_size, hidden_size]

2.3 Decoder

The decoder decides the next output word based on what has already been translated and the context vectors;

class PlainDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):   # much like PlainEncoder.forward, except the initial hidden state is passed in instead of being zero
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))              # [batch_size, y_lengths, embed_size=hidden_size]

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()    # [batch_size, y_lengths, hidden_size]
        hid = hid[:, original_idx.long()].contiguous()             # [1, batch_size, hidden_size]

        output = F.log_softmax(self.fc(output_seq), -1)            # [batch_size, y_lengths, vocab_size]

        return output, hid

2.4 Build the Seq2Seq model

The Seq2Seq module chains the encoder and decoder together (this version has no attention);

class PlainSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(PlainSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y, y_lengths, hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid = self.decoder(y=y, y_lengths=torch.ones(batch_size).long().to(y.device), hid=hid)
            y = output.max(2)[1].view(batch_size, 1)   # greedy decoding: feed the most likely word back in at the next step
            preds.append(y)

        return torch.cat(preds, 1), None
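
Before training, a quick smoke test on random data (purely illustrative, with tiny made-up sizes) confirms the output shapes:

enc = PlainEncoder(vocab_size=100, hidden_size=8)
dec = PlainDecoder(vocab_size=120, hidden_size=8)
m = PlainSeq2Seq(enc, dec)
x = torch.randint(0, 100, (2, 5))   # a batch of 2 "English" sentences of padded length 5
y = torch.randint(0, 120, (2, 6))   # the corresponding "Chinese" sentences of padded length 6
out, _ = m(x, torch.tensor([5, 4]), y, torch.tensor([6, 3]))
print(out.shape)   # torch.Size([2, 6, 120]): log-probabilities over the target vocabulary at every position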

3. Create the model

Define the model, the loss, and the optimizer.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100
encoder = PlainEncoder(vocab_size=en_total_words, hidden_size=hidden_size, dropout=dropout)
decoder = PlainDecoder(vocab_size=cn_total_words, hidden_size=hidden_size, dropout=dropout)
model = PlainSeq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())

4. Evaluate the model

def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()    # decoder input: target sentence without its last token
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()    # expected output: target sentence shifted left by one
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()
            mb_y_len[mb_y_len<=0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss/total_num_words)
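
The mask expression above builds a 0/1 matrix from the true lengths; a toy example of what it produces:

lengths = torch.tensor([3, 1])
mask = (torch.arange(lengths.max().item())[None, :] < lengths[:, None]).float()
print(mask)
# tensor([[1., 1., 1.],
#         [1., 0., 0.]])   # positions beyond each sentence's true length are excluded from the loss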

5. Train the model

def train(model, data, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()    # teacher forcing: feed the reference target, shifted
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()
            mb_y_len[mb_y_len<=0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update the model
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss/total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)

Train for 100 epochs:

train(model, train_data, num_epochs=100)

Training results (the training loss keeps decreasing):

  1 Epoch 0 iteration 0 loss 8.084440231323242
  2 Epoch 0 iteration 100 loss 4.8448944091796875
  3 Epoch 0 iteration 200 loss 4.879772663116455
  4 Epoch 0 Training loss 5.477221919210141
  5 Evaluation loss 4.821030395389826
  6 Epoch 1 iteration 0 loss 4.69868278503418
  7 Epoch 1 iteration 100 loss 4.085171699523926
  8 Epoch 1 iteration 200 loss 4.312857151031494
  9 Epoch 1 Training loss 4.579521701350524
 10 Epoch 2 iteration 0 loss 4.193971633911133
 11 Epoch 2 iteration 100 loss 3.678673267364502
 12 Epoch 2 iteration 200 loss 4.019515514373779
 13 Epoch 2 Training loss 4.186071368925457
 14 Epoch 3 iteration 0 loss 3.8352835178375244
 15 Epoch 3 iteration 100 loss 3.3954527378082275
 16 Epoch 3 iteration 200 loss 3.774580240249634
 17 Epoch 3 Training loss 3.9222166424267986
 18 Epoch 4 iteration 0 loss 3.585063934326172
 19 Epoch 4 iteration 100 loss 3.215750217437744
 20 Epoch 4 iteration 200 loss 3.626997232437134
 21 Epoch 4 Training loss 3.722608096150466
 22 Epoch 5 iteration 0 loss 3.411375045776367
 23 Epoch 5 iteration 100 loss 3.0424859523773193
 24 Epoch 5 iteration 200 loss 3.492255926132202
 25 Epoch 5 Training loss 3.5699179079587195
 26 Evaluation loss 3.655821240952787
 27 Epoch 6 iteration 0 loss 3.273927927017212
 28 Epoch 6 iteration 100 loss 2.897022247314453
 29 Epoch 6 iteration 200 loss 3.355715036392212
 30 Epoch 6 Training loss 3.4411540739967426
 31 Epoch 7 iteration 0 loss 3.16508412361145
 32 Epoch 7 iteration 100 loss 2.7818763256073
 33 Epoch 7 iteration 200 loss 3.241000175476074
 34 Epoch 7 Training loss 3.330995073153501
 35 Epoch 8 iteration 0 loss 3.081458806991577
 36 Epoch 8 iteration 100 loss 2.692844867706299
 37 Epoch 8 iteration 200 loss 3.159105062484741
 38 Epoch 8 Training loss 3.237538761219645
 39 Epoch 9 iteration 0 loss 2.983361005783081
 40 Epoch 9 iteration 100 loss 2.5852301120758057
 41 Epoch 9 iteration 200 loss 3.076793670654297
 42 Epoch 9 Training loss 3.1542968146839754
 43 Epoch 10 iteration 0 loss 2.88155198097229
 44 Epoch 10 iteration 100 loss 2.504387617111206
 45 Epoch 10 iteration 200 loss 2.9708898067474365
 46 Epoch 10 Training loss 3.0766581801071924
 47 Evaluation loss 3.3804360915245204
 48 Epoch 11 iteration 0 loss 2.805739164352417
 49 Epoch 11 iteration 100 loss 2.417832612991333
 50 Epoch 11 iteration 200 loss 2.9001076221466064
 51 Epoch 11 Training loss 3.0072335865815747
 52 Epoch 12 iteration 0 loss 2.7389864921569824
 53 Epoch 12 iteration 100 loss 2.352132558822632
 54 Epoch 12 iteration 200 loss 2.864527702331543
 55 Epoch 12 Training loss 2.945309993148362
 56 Epoch 13 iteration 0 loss 2.6841001510620117
 57 Epoch 13 iteration 100 loss 2.2722346782684326
 58 Epoch 13 iteration 200 loss 2.8002915382385254
 59 Epoch 13 Training loss 2.8879525671218156
 60 Epoch 14 iteration 0 loss 2.641491651535034
 61 Epoch 14 iteration 100 loss 2.237807273864746
 62 Epoch 14 iteration 200 loss 2.7538034915924072
 63 Epoch 14 Training loss 2.833802188663957
 64 Epoch 15 iteration 0 loss 2.5613601207733154
 65 Epoch 15 iteration 100 loss 2.149299144744873
 66 Epoch 15 iteration 200 loss 2.671037435531616
 67 Epoch 15 Training loss 2.7850014679518598
 68 Evaluation loss 3.2569677577366516
 69 Epoch 16 iteration 0 loss 2.5330140590667725
 70 Epoch 16 iteration 100 loss 2.0988974571228027
 71 Epoch 16 iteration 200 loss 2.611022472381592
 72 Epoch 16 Training loss 2.7354116963192716
 73 Epoch 17 iteration 0 loss 2.485084295272827
 74 Epoch 17 iteration 100 loss 2.0532665252685547
 75 Epoch 17 iteration 200 loss 2.604226589202881
 76 Epoch 17 Training loss 2.6934350694497957
 77 Epoch 18 iteration 0 loss 2.4521820545196533
 78 Epoch 18 iteration 100 loss 2.0395381450653076
 79 Epoch 18 iteration 200 loss 2.5578808784484863
 80 Epoch 18 Training loss 2.651303096776386
 81 Epoch 19 iteration 0 loss 2.390338182449341
 82 Epoch 19 iteration 100 loss 1.9780246019363403
 83 Epoch 19 iteration 200 loss 2.5150232315063477
 84 Epoch 19 Training loss 2.611681331448251
 85 Epoch 20 iteration 0 loss 2.352649211883545
 86 Epoch 20 iteration 100 loss 1.9426053762435913
 87 Epoch 20 iteration 200 loss 2.4782586097717285
 88 Epoch 20 Training loss 2.5747013451744616
 89 Evaluation loss 3.194680030596711
 90 Epoch 21 iteration 0 loss 2.3205008506774902
 91 Epoch 21 iteration 100 loss 1.9143742322921753
 92 Epoch 21 iteration 200 loss 2.4607479572296143
 93 Epoch 21 Training loss 2.5404243457594116
 94 Epoch 22 iteration 0 loss 2.3100969791412354
 95 Epoch 22 iteration 100 loss 1.912932276725769
 96 Epoch 22 iteration 200 loss 2.4103682041168213
 97 Epoch 22 Training loss 2.507626390779296
 98 Epoch 23 iteration 0 loss 2.228956699371338
 99 Epoch 23 iteration 100 loss 1.8543353080749512
100 Epoch 23 iteration 200 loss 2.3663489818573
101 Epoch 23 Training loss 2.475231424650597
102 Epoch 24 iteration 0 loss 2.199277639389038
103 Epoch 24 iteration 100 loss 1.8272788524627686
104 Epoch 24 iteration 200 loss 2.3518714904785156
105 Epoch 24 Training loss 2.4439996520576863
106 Epoch 25 iteration 0 loss 2.198460817337036
107 Epoch 25 iteration 100 loss 1.7921738624572754
108 Epoch 25 iteration 200 loss 2.3299384117126465
109 Epoch 25 Training loss 2.416539151404694
110 Evaluation loss 3.1583419660450347
111 Epoch 26 iteration 0 loss 2.1647706031799316
112 Epoch 26 iteration 100 loss 1.725657343864441
113 Epoch 26 iteration 200 loss 2.268852710723877
114 Epoch 26 Training loss 2.3919890312051444
115 Epoch 27 iteration 0 loss 2.1400880813598633
116 Epoch 27 iteration 100 loss 1.7474910020828247
117 Epoch 27 iteration 200 loss 2.256742000579834
118 Epoch 27 Training loss 2.3595162004913086
119 Epoch 28 iteration 0 loss 2.0979115962982178
120 Epoch 28 iteration 100 loss 1.7000322341918945
121 Epoch 28 iteration 200 loss 2.2546005249023438
122 Epoch 28 Training loss 2.3335356415568618
123 Epoch 29 iteration 0 loss 2.1031572818756104
124 Epoch 29 iteration 100 loss 1.6599613428115845
125 Epoch 29 iteration 200 loss 2.2020833492279053
126 Epoch 29 Training loss 2.311978717884133
127 Epoch 30 iteration 0 loss 2.041980028152466
128 Epoch 30 iteration 100 loss 1.6663353443145752
129 Epoch 30 iteration 200 loss 2.1463098526000977
130 Epoch 30 Training loss 2.2902015222655807
131 Evaluation loss 3.133273747140961
132 Epoch 31 iteration 0 loss 2.0045719146728516
133 Epoch 31 iteration 100 loss 1.6515719890594482
134 Epoch 31 iteration 200 loss 2.1130664348602295
135 Epoch 31 Training loss 2.2633183437027657
136 Epoch 32 iteration 0 loss 1.9948643445968628
137 Epoch 32 iteration 100 loss 1.6262538433074951
138 Epoch 32 iteration 200 loss 2.1329450607299805
139 Epoch 32 Training loss 2.242057023454951
140 Epoch 33 iteration 0 loss 1.9623773097991943
141 Epoch 33 iteration 100 loss 1.6022558212280273
142 Epoch 33 iteration 200 loss 2.092766523361206
143 Epoch 33 Training loss 2.219300144243463
144 Epoch 34 iteration 0 loss 1.929176688194275
145 Epoch 34 iteration 100 loss 1.57985258102417
146 Epoch 34 iteration 200 loss 2.067972183227539
147 Epoch 34 Training loss 2.199957146669663
148 Epoch 35 iteration 0 loss 1.9449653625488281
149 Epoch 35 iteration 100 loss 1.5760831832885742
150 Epoch 35 iteration 200 loss 2.056731939315796
151 Epoch 35 Training loss 2.1790822226814464
152 Evaluation loss 3.13363336627263
153 Epoch 36 iteration 0 loss 1.8961074352264404
154 Epoch 36 iteration 100 loss 1.5195672512054443
155 Epoch 36 iteration 200 loss 2.0268213748931885
156 Epoch 36 Training loss 2.160204240618562
157 Epoch 37 iteration 0 loss 1.9172203540802002
158 Epoch 37 iteration 100 loss 1.495902180671692
159 Epoch 37 iteration 200 loss 1.9827772378921509
160 Epoch 37 Training loss 2.139063811380212
161 Epoch 38 iteration 0 loss 1.8988227844238281
162 Epoch 38 iteration 100 loss 1.5224453210830688
163 Epoch 38 iteration 200 loss 1.972291111946106
164 Epoch 38 Training loss 2.1211086652629887
165 Epoch 39 iteration 0 loss 1.8728121519088745
166 Epoch 39 iteration 100 loss 1.4476994276046753
167 Epoch 39 iteration 200 loss 1.9898269176483154
168 Epoch 39 Training loss 2.1024907934743258
169 Epoch 40 iteration 0 loss 1.8664008378982544
170 Epoch 40 iteration 100 loss 1.4997611045837402
171 Epoch 40 iteration 200 loss 1.9541966915130615
172 Epoch 40 Training loss 2.086313187411815
173 Evaluation loss 3.1282314096494708
174 Epoch 41 iteration 0 loss 1.865237832069397
175 Epoch 41 iteration 100 loss 1.4755399227142334
176 Epoch 41 iteration 200 loss 1.9337103366851807
177 Epoch 41 Training loss 2.068258631932244
178 Epoch 42 iteration 0 loss 1.790804147720337
179 Epoch 42 iteration 100 loss 1.4380069971084595
180 Epoch 42 iteration 200 loss 1.9523491859436035
181 Epoch 42 Training loss 2.0498001934027874
182 Epoch 43 iteration 0 loss 1.7979768514633179
183 Epoch 43 iteration 100 loss 1.436006784439087
184 Epoch 43 iteration 200 loss 1.9101322889328003
185 Epoch 43 Training loss 2.0354298580230195
186 Epoch 44 iteration 0 loss 1.7717180252075195
187 Epoch 44 iteration 100 loss 1.412601351737976
188 Epoch 44 iteration 200 loss 1.8883790969848633
189 Epoch 44 Training loss 2.0182710578663032
190 Epoch 45 iteration 0 loss 1.7614871263504028
191 Epoch 45 iteration 100 loss 1.3429900407791138
192 Epoch 45 iteration 200 loss 1.862486720085144
193 Epoch 45 Training loss 2.0034489605129595
194 Evaluation loss 3.13050353642062
195 Epoch 46 iteration 0 loss 1.753187656402588
196 Epoch 46 iteration 100 loss 1.3810824155807495
197 Epoch 46 iteration 200 loss 1.8526273965835571
198 Epoch 46 Training loss 1.9899710891643612
199 Epoch 47 iteration 0 loss 1.7567869424819946
200 Epoch 47 iteration 100 loss 1.3430988788604736
201 Epoch 47 iteration 200 loss 1.8135911226272583
202 Epoch 47 Training loss 1.9723690433387957
203 Epoch 48 iteration 0 loss 1.7263280153274536
204 Epoch 48 iteration 100 loss 1.3430798053741455
205 Epoch 48 iteration 200 loss 1.8229252099990845
206 Epoch 48 Training loss 1.9580909331705005
207 Epoch 49 iteration 0 loss 1.731834888458252
208 Epoch 49 iteration 100 loss 1.325390100479126
209 Epoch 49 iteration 200 loss 1.8075029850006104
210 Epoch 49 Training loss 1.9418853706725143
211 Epoch 50 iteration 0 loss 1.7218893766403198
212 Epoch 50 iteration 100 loss 1.2710607051849365
213 Epoch 50 iteration 200 loss 1.8196479082107544
214 Epoch 50 Training loss 1.9300463292027463
215 Evaluation loss 3.1402900424368902
216 Epoch 51 iteration 0 loss 1.701721429824829
217 Epoch 51 iteration 100 loss 1.2720820903778076
218 Epoch 51 iteration 200 loss 1.7759710550308228
219 Epoch 51 Training loss 1.9192517232508806
220 Epoch 52 iteration 0 loss 1.7286512851715088
221 Epoch 52 iteration 100 loss 1.2737478017807007
222 Epoch 52 iteration 200 loss 1.7545547485351562
223 Epoch 52 Training loss 1.906238278183267
224 Epoch 53 iteration 0 loss 1.6672327518463135
225 Epoch 53 iteration 100 loss 1.3138436079025269
226 Epoch 53 iteration 200 loss 1.8045201301574707
227 Epoch 53 Training loss 1.8922825534741075
228 Epoch 54 iteration 0 loss 1.617557168006897
229 Epoch 54 iteration 100 loss 1.22885262966156
230 Epoch 54 iteration 200 loss 1.7750707864761353
231 Epoch 54 Training loss 1.8807705430479014
232 Epoch 55 iteration 0 loss 1.66348135471344
233 Epoch 55 iteration 100 loss 1.2331219911575317
234 Epoch 55 iteration 200 loss 1.7303975820541382
235 Epoch 55 Training loss 1.867195544079556
236 Evaluation loss 3.145431456349013
237 Epoch 56 iteration 0 loss 1.6259342432022095
238 Epoch 56 iteration 100 loss 1.2141388654708862
239 Epoch 56 iteration 200 loss 1.6984847784042358
240 Epoch 56 Training loss 1.8548133653506713
241 Epoch 57 iteration 0 loss 1.605487585067749
242 Epoch 57 iteration 100 loss 1.1920335292816162
243 Epoch 57 iteration 200 loss 1.7253336906433105
244 Epoch 57 Training loss 1.8387836396466541
245 Epoch 58 iteration 0 loss 1.600136160850525
246 Epoch 58 iteration 100 loss 1.2192472219467163
247 Epoch 58 iteration 200 loss 1.6888371706008911
248 Epoch 58 Training loss 1.83046734055076
249 Epoch 59 iteration 0 loss 1.6042535305023193
250 Epoch 59 iteration 100 loss 1.2362377643585205
251 Epoch 59 iteration 200 loss 1.6654771566390991
252 Epoch 59 Training loss 1.8226244935892273
253 Epoch 60 iteration 0 loss 1.5602766275405884
254 Epoch 60 iteration 100 loss 1.201045036315918
255 Epoch 60 iteration 200 loss 1.6702684164047241
256 Epoch 60 Training loss 1.8102721190615219
257 Evaluation loss 3.154303393916162
258 Epoch 61 iteration 0 loss 1.5679781436920166
259 Epoch 61 iteration 100 loss 1.2105367183685303
260 Epoch 61 iteration 200 loss 1.6650742292404175
261 Epoch 61 Training loss 1.7970227477404426
262 Epoch 62 iteration 0 loss 1.5734565258026123
263 Epoch 62 iteration 100 loss 1.1602052450180054
264 Epoch 62 iteration 200 loss 1.583187222480774
265 Epoch 62 Training loss 1.787027303402099
266 Epoch 63 iteration 0 loss 1.563283920288086
267 Epoch 63 iteration 100 loss 1.1829460859298706
268 Epoch 63 iteration 200 loss 1.6458944082260132
269 Epoch 63 Training loss 1.7742324239103342
270 Epoch 64 iteration 0 loss 1.5429617166519165
271 Epoch 64 iteration 100 loss 1.1225509643554688
272 Epoch 64 iteration 200 loss 1.6353931427001953
273 Epoch 64 Training loss 1.7665018986396424
274 Epoch 65 iteration 0 loss 1.5284583568572998
275 Epoch 65 iteration 100 loss 1.1426113843917847
276 Epoch 65 iteration 200 loss 1.6138485670089722
277 Epoch 65 Training loss 1.7557591437816458
278 Evaluation loss 3.166533922994568
279 Epoch 66 iteration 0 loss 1.5184751749038696
280 Epoch 66 iteration 100 loss 1.127056360244751
281 Epoch 66 iteration 200 loss 1.611910343170166
282 Epoch 66 Training loss 1.7446940747065838
283 Epoch 67 iteration 0 loss 1.4880752563476562
284 Epoch 67 iteration 100 loss 1.1075133085250854
285 Epoch 67 iteration 200 loss 1.6138321161270142
286 Epoch 67 Training loss 1.7374662356132202
287 Epoch 68 iteration 0 loss 1.5260978937149048
288 Epoch 68 iteration 100 loss 1.12235689163208
289 Epoch 68 iteration 200 loss 1.6129950284957886
290 Epoch 68 Training loss 1.7253250324901928
291 Epoch 69 iteration 0 loss 1.5172449350357056
292 Epoch 69 iteration 100 loss 1.1174883842468262
293 Epoch 69 iteration 200 loss 1.551174283027649
294 Epoch 69 Training loss 1.7166664929363027
295 Epoch 70 iteration 0 loss 1.5006300210952759
296 Epoch 70 iteration 100 loss 1.0905342102050781
297 Epoch 70 iteration 200 loss 1.5446460247039795
298 Epoch 70 Training loss 1.70989819337649
299 Evaluation loss 3.1750113054724385
300 Epoch 71 iteration 0 loss 1.4726097583770752
301 Epoch 71 iteration 100 loss 1.086822509765625
302 Epoch 71 iteration 200 loss 1.5575647354125977
303 Epoch 71 Training loss 1.697000935158525
304 Epoch 72 iteration 0 loss 1.449334979057312
305 Epoch 72 iteration 100 loss 1.0667144060134888
306 Epoch 72 iteration 200 loss 1.530726671218872
307 Epoch 72 Training loss 1.6881878283419123
308 Epoch 73 iteration 0 loss 1.4603246450424194
309 Epoch 73 iteration 100 loss 1.0751914978027344
310 Epoch 73 iteration 200 loss 1.5088605880737305
311 Epoch 73 Training loss 1.6805761044806562
312 Epoch 74 iteration 0 loss 1.4748084545135498
313 Epoch 74 iteration 100 loss 1.0556395053863525
314 Epoch 74 iteration 200 loss 1.5206905603408813
315 Epoch 74 Training loss 1.6673887956853506
316 Epoch 75 iteration 0 loss 1.454646348953247
317 Epoch 75 iteration 100 loss 1.0396276712417603
318 Epoch 75 iteration 200 loss 1.518398404121399
319 Epoch 75 Training loss 1.6633919350661184
320 Evaluation loss 3.189181657332237
321 Epoch 76 iteration 0 loss 1.4616646766662598
322 Epoch 76 iteration 100 loss 0.9838554859161377
323 Epoch 76 iteration 200 loss 1.4613702297210693
324 Epoch 76 Training loss 1.6526747506920867
325 Epoch 77 iteration 0 loss 1.4646761417388916
326 Epoch 77 iteration 100 loss 1.0383753776550293
327 Epoch 77 iteration 200 loss 1.5081768035888672
328 Epoch 77 Training loss 1.6462943129725018
329 Epoch 78 iteration 0 loss 1.4008097648620605
330 Epoch 78 iteration 100 loss 1.0147686004638672
331 Epoch 78 iteration 200 loss 1.5017434358596802
332 Epoch 78 Training loss 1.6352284007247493
333 Epoch 79 iteration 0 loss 1.4189144372940063
334 Epoch 79 iteration 100 loss 1.0126101970672607
335 Epoch 79 iteration 200 loss 1.4195480346679688
336 Epoch 79 Training loss 1.628015456811747
337 Epoch 80 iteration 0 loss 1.4199804067611694
338 Epoch 80 iteration 100 loss 1.0256879329681396
339 Epoch 80 iteration 200 loss 1.4564563035964966
340 Epoch 80 Training loss 1.6227562783981957
341 Evaluation loss 3.2074876046135703
342 Epoch 81 iteration 0 loss 1.431972622871399
343 Epoch 81 iteration 100 loss 1.0110960006713867
344 Epoch 81 iteration 200 loss 1.4414775371551514
345 Epoch 81 Training loss 1.6157781071711008
346 Epoch 82 iteration 0 loss 1.4158073663711548
347 Epoch 82 iteration 100 loss 0.9702512621879578
348 Epoch 82 iteration 200 loss 1.4209394454956055
349 Epoch 82 Training loss 1.605166310639776
350 Epoch 83 iteration 0 loss 1.3871146440505981
351 Epoch 83 iteration 100 loss 1.0183656215667725
352 Epoch 83 iteration 200 loss 1.4292359352111816
353 Epoch 83 Training loss 1.5961119023327037
354 Epoch 84 iteration 0 loss 1.3919366598129272
355 Epoch 84 iteration 100 loss 0.9692129492759705
356 Epoch 84 iteration 200 loss 1.4092985391616821
357 Epoch 84 Training loss 1.5897755956223851
358 Epoch 85 iteration 0 loss 1.355398416519165
359 Epoch 85 iteration 100 loss 0.9916797280311584
360 Epoch 85 iteration 200 loss 1.423561453819275
361 Epoch 85 Training loss 1.5878568289810793
362 Evaluation loss 3.2138472480503295
363 Epoch 86 iteration 0 loss 1.351928472518921
364 Epoch 86 iteration 100 loss 0.9997824430465698
365 Epoch 86 iteration 200 loss 1.4049323797225952
366 Epoch 86 Training loss 1.5719682346027806
367 Epoch 87 iteration 0 loss 1.3508714437484741
368 Epoch 87 iteration 100 loss 0.9411044716835022
369 Epoch 87 iteration 200 loss 1.4019731283187866
370 Epoch 87 Training loss 1.5641802139809575
371 Epoch 88 iteration 0 loss 1.347946047782898
372 Epoch 88 iteration 100 loss 0.9493017792701721
373 Epoch 88 iteration 200 loss 1.3770906925201416
374 Epoch 88 Training loss 1.5587840858982533
375 Epoch 89 iteration 0 loss 1.320084571838379
376 Epoch 89 iteration 100 loss 0.9223963022232056
377 Epoch 89 iteration 200 loss 1.4065088033676147
378 Epoch 89 Training loss 1.5548267858027334
379 Epoch 90 iteration 0 loss 1.3534889221191406
380 Epoch 90 iteration 100 loss 0.9281108975410461
381 Epoch 90 iteration 200 loss 1.3821330070495605
382 Epoch 90 Training loss 1.5474867314671616
383 Evaluation loss 3.2276618163204667
384 Epoch 91 iteration 0 loss 1.3667511940002441
385 Epoch 91 iteration 100 loss 0.8797598481178284
386 Epoch 91 iteration 200 loss 1.3776274919509888
387 Epoch 91 Training loss 1.536482189982952
388 Epoch 92 iteration 0 loss 1.3355433940887451
389 Epoch 92 iteration 100 loss 0.9130176901817322
390 Epoch 92 iteration 200 loss 1.3042923212051392
391 Epoch 92 Training loss 1.5308507835779057
392 Epoch 93 iteration 0 loss 1.2953367233276367
393 Epoch 93 iteration 100 loss 0.9194003939628601
394 Epoch 93 iteration 200 loss 1.3469970226287842
395 Epoch 93 Training loss 1.519625581403501
396 Epoch 94 iteration 0 loss 1.322600245475769
397 Epoch 94 iteration 100 loss 0.9003701210021973
398 Epoch 94 iteration 200 loss 1.3512846231460571
399 Epoch 94 Training loss 1.5193673748787049
400 Epoch 95 iteration 0 loss 1.2789180278778076
401 Epoch 95 iteration 100 loss 0.9352515339851379
402 Epoch 95 iteration 200 loss 1.3609877824783325
403 Epoch 95 Training loss 1.5135782739054082
404 Evaluation loss 3.2474015759319284
405 Epoch 96 iteration 0 loss 1.3051612377166748
406 Epoch 96 iteration 100 loss 0.8885603547096252
407 Epoch 96 iteration 200 loss 1.3272497653961182
408 Epoch 96 Training loss 1.5079536183100883
409 Epoch 97 iteration 0 loss 1.2671339511871338
410 Epoch 97 iteration 100 loss 0.8706735968589783
411 Epoch 97 iteration 200 loss 1.305412769317627
412 Epoch 97 Training loss 1.4974833326540824
413 Epoch 98 iteration 0 loss 1.308292269706726
414 Epoch 98 iteration 100 loss 0.9079441428184509
415 Epoch 98 iteration 200 loss 1.2940715551376343
416 Epoch 98 Training loss 1.4928753682563118
417 Epoch 99 iteration 0 loss 1.276250958442688
418 Epoch 99 iteration 100 loss 0.890657901763916
419 Epoch 99 iteration 200 loss 1.3286609649658203
420 Epoch 99 Training loss 1.4852960116094391

6. Translation

def translate_dev(i):
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])   # original English sentence
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])   # reference Chinese translation
    print("".join(cn_sent))

    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

    translation, attn = model.translate(mb_x, mb_x_len, bos)
    translation = [inv_cn_dict[i] for i in translation.data.cpu().numpy().reshape(-1)]
    trans = []
    for word in translation:
        if word != "EOS":
            trans.append(word)
        else:
            break
    print("".join(trans))   # the model's Chinese translation

for i in range(100, 120):
    translate_dev(i)
    print()

The output is shown below (with a small corpus and a short training run, the translation quality is poor):

 1 BOS you have nice skin . EOS
 2 BOS 你 的 皮 膚 真 好 。 EOS
 3 你有很多好。
 4 
 5 BOS you re UNK correct . EOS
 6 BOS 你 部 分 正 确 。 EOS
 7 你是个好人。
 8 
 9 BOS everyone admired his courage . EOS
10 BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
11 他們的電話讓他們一個
12 
13 BOS what time is it ? EOS
14 BOS 几 点 了 ? EOS
15 它还是什么?
16 
17 BOS i m free tonight . EOS
18 BOS 我 今 晚 有 空 。 EOS
19 我今晚有空。
20 
21 BOS here is your book . EOS
22 BOS 這 是 你 的 書 。 EOS
23 你的書桌是舊。
24 
25 BOS they are at lunch . EOS
26 BOS 他 们 在 吃 午 饭 。 EOS
27 他们正在吃米饭。
28 
29 BOS this chair is UNK . EOS
30 BOS 這 把 椅 子 很 UNK 。 EOS
31 這是真的最好的人。
32 
33 BOS it s pretty heavy . EOS
34 BOS 它 真 重 。 EOS
35 它是真的。
36 
37 BOS many attended his funeral . EOS
38 BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
39 AI他的襪子。
40 
41 BOS training will be provided . EOS
42 BOS 会 有 训 练 。 EOS
43 人们的货品造成的。
44 
45 BOS someone is watching you . EOS
46 BOS 有 人 在 看 著 你 。 EOS
47 有人叫醒汤姆。
48 
49 BOS i slapped his face . EOS
50 BOS 我 摑 了 他 的 臉 。 EOS
51 我有他的兄弟。
52 
53 BOS i like UNK music . EOS
54 BOS 我 喜 歡 流 行 音 樂 。 EOS
55 我喜欢狗在家。
56 
57 BOS tom had no children . EOS
58 BOS T o m 沒 有 孩 子 。 EOS
59 汤姆不需要做什么。
60 
61 BOS please lock the door . EOS
62 BOS 請 把 門 鎖 上 。 EOS
63 請把門打開。
64 
65 BOS tom has calmed down . EOS
66 BOS 汤 姆 冷 静 下 来 了 。 EOS
67 汤姆睡着了。
68 
69 BOS please speak more loudly . EOS
70 BOS 請 說 大 聲 一 點 兒 。 EOS
71 請講慢一點。
72 
73 BOS keep next sunday free . EOS
74 BOS 把 下 周 日 空 出 来 。 EOS
75 下午可以轉下。
76 
77 BOS i made a mistake . EOS
78 BOS 我 犯 了 一 個 錯 。 EOS
79 我有些意生。

7. Encoder-Decoder model (with attention)

7.1 Encoder

The encoder passes the input tokens through an embedding layer and a GRU, producing hidden states that later serve as the context vectors;

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        hid = torch.cat([hid[-2], hid[-1]], dim=1)       # bidirectional GRU, so concatenate the two final hidden states
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)      # project to the decoder's hidden size: [1, batch_size, dec_hidden_size]

        return out, hid

7.2 Luong attention

[Figure omitted: Luong attention score functions]

In the figure, ht is the decoder output and hs is the encoder context; the score is computed with the second method;

The attention module computes the output from the context vectors and the current decoder hidden states;
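
For reference, the three score functions in Luong et al. (2015) are: dot, score(h_t, h_s) = h_t^T h_s; general, score(h_t, h_s) = h_t^T W_a h_s; and concat, score(h_t, h_s) = v_a^T tanh(W_a [h_t; h_s]). The code below uses the second ("general") form: linear_in plays the role of W_a, and torch.bmm evaluates h_t^T W_a h_s for every decoder position against every encoder position in the batch.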

class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()

        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size

        self.linear_in = nn.Linear(enc_hidden_size*2, dec_hidden_size, bias=False)   # the W_a of the "general" score
        self.linear_out = nn.Linear(enc_hidden_size*2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        # output: [batch_size, output_len, dec_hidden_size]
        # context: [batch_size, context_len, 2*enc_hidden_size]

        batch_size = output.size(0)
        output_len = output.size(1)
        context_len = context.size(1)

        # context_in: [batch_size, context_len, dec_hidden_size]
        context_in = self.linear_in(context.view(batch_size*context_len, -1)).view(batch_size, context_len, -1)

        # context_in.transpose(1,2): [batch_size, dec_hidden_size, context_len]
        attn = torch.bmm(output, context_in.transpose(1,2))   # attn: [batch_size, output_len, context_len]

        attn = attn.masked_fill(mask, -1e6)   # set padded positions to a large negative value before the softmax

        attn = F.softmax(attn, dim=2)         # attn: [batch_size, output_len, context_len]

        context = torch.bmm(attn, context)    # context: [batch_size, output_len, 2*enc_hidden_size]

        output = torch.cat((context, output), dim=2)   # output: [batch_size, output_len, 2*enc_hidden_size + dec_hidden_size]

        output = output.view(batch_size*output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)

        return output, attn

7.3 Decoder

The decoder decides the next output word based on what has already been translated and the context vectors;

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):   # boolean mask of shape [batch_size, max_x_len, max_y_len]; True marks positions to mask out
        device = x_len.device
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=x_len.device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=x_len.device)[None, :] < y_len[:, None]
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])   # True wherever either position is padding
        return mask

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))   # [batch_size, output_length, embed_size]

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        mask = self.create_mask(y_lengths, ctx_lengths)

        output, attn = self.attention(output_seq, ctx, mask)   # attend over the encoder outputs (ctx) from each decoder position
        output = F.log_softmax(self.out(output), -1)

        return output, hid, attn

7.4 Build the model and call the train function defined above
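
The attention-based Seq2Seq wrapper used below is not shown in the post. A minimal sketch is given here, assumed to mirror PlainSeq2Seq but additionally passing the encoder outputs and the input lengths to the decoder so that attention can be computed:

class Seq2Seq(nn.Module):
    # Assumed wrapper (not from the original post): same role as PlainSeq2Seq,
    # but the decoder also receives the encoder outputs and input lengths for attention.
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out, ctx_lengths=x_lengths,
                                         y=y, y_lengths=y_lengths, hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=10):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(ctx=encoder_out, ctx_lengths=x_lengths,
                                             y=y, y_lengths=torch.ones(batch_size).long().to(y.device),
                                             hid=hid)
            y = output.max(2)[1].view(batch_size, 1)   # greedy decoding, as in PlainSeq2Seq
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)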

dropout = 0.2
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words, embed_size=embed_size, enc_hidden_size=hidden_size, dec_hidden_size=hidden_size, dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words, embed_size=embed_size, enc_hidden_size=hidden_size, dec_hidden_size=hidden_size, dropout=dropout)
model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())

train(model, train_data, num_epochs=100)

Training results:

  1 Epoch 0 iteration 0 loss 1.3026132583618164
  2 Epoch 0 iteration 100 loss 0.8847191333770752
  3 Epoch 0 iteration 200 loss 1.285671353340149
  4 Epoch 0 Training loss 1.4803871257447
  5 Evaluation loss 3.260634059314127
  6 Epoch 1 iteration 0 loss 1.2434465885162354
  7 Epoch 1 iteration 100 loss 0.8472312092781067
  8 Epoch 1 iteration 200 loss 1.282746434211731
  9 Epoch 1 Training loss 1.4731217075462495
 10 Epoch 2 iteration 0 loss 1.2593930959701538
 11 Epoch 2 iteration 100 loss 0.8484001159667969
 12 Epoch 2 iteration 200 loss 1.2862968444824219
 13 Epoch 2 Training loss 1.466728115795638
 14 Epoch 3 iteration 0 loss 1.2554501295089722
 15 Epoch 3 iteration 100 loss 0.9115875363349915
 16 Epoch 3 iteration 200 loss 1.2563236951828003
 17 Epoch 3 Training loss 1.4607854827943827
 18 Epoch 4 iteration 0 loss 1.217956304550171
 19 Epoch 4 iteration 100 loss 0.8641748428344727
 20 Epoch 4 iteration 200 loss 1.2998305559158325
 21 Epoch 4 Training loss 1.4587181747145395
 22 Epoch 5 iteration 0 loss 1.258739709854126
 23 Epoch 5 iteration 100 loss 0.8705984354019165
 24 Epoch 5 iteration 200 loss 1.2102816104888916
 25 Epoch 5 Training loss 1.4507371513452623
 26 Evaluation loss 3.266629208261664
 27 Epoch 6 iteration 0 loss 1.259811282157898
 28 Epoch 6 iteration 100 loss 0.8492067456245422
 29 Epoch 6 iteration 200 loss 1.3064922094345093
 30 Epoch 6 Training loss 1.4432446458560053
 31 Epoch 7 iteration 0 loss 1.2411160469055176
 32 Epoch 7 iteration 100 loss 0.8373231291770935
 33 Epoch 7 iteration 200 loss 1.2500189542770386
 34 Epoch 7 Training loss 1.436364381060567
 35 Epoch 8 iteration 0 loss 1.1868956089019775
 36 Epoch 8 iteration 100 loss 0.814584493637085
 37 Epoch 8 iteration 200 loss 1.2773609161376953
 38 Epoch 8 Training loss 1.4354508132900903
 39 Epoch 9 iteration 0 loss 1.2234464883804321
 40 Epoch 9 iteration 100 loss 0.797888457775116
 41 Epoch 9 iteration 200 loss 1.2435855865478516
 42 Epoch 9 Training loss 1.424914875345232
 43 Epoch 10 iteration 0 loss 1.2067270278930664
 44 Epoch 10 iteration 100 loss 0.8425077795982361
 45 Epoch 10 iteration 200 loss 1.2325958013534546
 46 Epoch 10 Training loss 1.4212906077384722
 47 Evaluation loss 3.2876189393276327
 48 Epoch 11 iteration 0 loss 1.221406102180481
 49 Epoch 11 iteration 100 loss 0.80806964635849
 50 Epoch 11 iteration 200 loss 1.3028448820114136
 51 Epoch 11 Training loss 1.4154276829998698
 52 Epoch 12 iteration 0 loss 1.1890984773635864
 53 Epoch 12 iteration 100 loss 0.827181875705719
 54 Epoch 12 iteration 200 loss 1.1675362586975098
 55 Epoch 12 Training loss 1.4132606964483012
 56 Epoch 13 iteration 0 loss 1.2002121210098267
 57 Epoch 13 iteration 100 loss 0.8232781291007996
 58 Epoch 13 iteration 200 loss 1.2605061531066895
 59 Epoch 13 Training loss 1.407515715564216
 60 Epoch 14 iteration 0 loss 1.1855664253234863
 61 Epoch 14 iteration 100 loss 0.8178666234016418
 62 Epoch 14 iteration 200 loss 1.2378345727920532
 63 Epoch 14 Training loss 1.3966619677770713
 64 Epoch 15 iteration 0 loss 1.1885008811950684
 65 Epoch 15 iteration 100 loss 0.7523401975631714
 66 Epoch 15 iteration 200 loss 1.1757400035858154
 67 Epoch 15 Training loss 1.3940533722612007
 68 Evaluation loss 3.3011410674061716
 69 Epoch 16 iteration 0 loss 1.185882806777954
 70 Epoch 16 iteration 100 loss 0.8129084706306458
 71 Epoch 16 iteration 200 loss 1.2022055387496948
 72 Epoch 16 Training loss 1.3908352825348185
 73 Epoch 17 iteration 0 loss 1.145820140838623
 74 Epoch 17 iteration 100 loss 0.7933529615402222
 75 Epoch 17 iteration 200 loss 1.1954973936080933
 76 Epoch 17 Training loss 1.3862186002415022
 77 Epoch 18 iteration 0 loss 1.1626101732254028
 78 Epoch 18 iteration 100 loss 0.8041335940361023
 79 Epoch 18 iteration 200 loss 1.1879560947418213
 80 Epoch 18 Training loss 1.3828558699833502
 81 Epoch 19 iteration 0 loss 1.1661605834960938
 82 Epoch 19 iteration 100 loss 0.7746578454971313
 83 Epoch 19 iteration 200 loss 1.167975902557373
 84 Epoch 19 Training loss 1.3737090146397222
 85 Epoch 20 iteration 0 loss 1.1992604732513428
 86 Epoch 20 iteration 100 loss 0.7750277519226074
 87 Epoch 20 iteration 200 loss 1.1533249616622925
 88 Epoch 20 Training loss 1.3699699581049805
 89 Evaluation loss 3.316624780553762
 90 Epoch 21 iteration 0 loss 1.182730793952942
 91 Epoch 21 iteration 100 loss 0.7664387822151184
 92 Epoch 21 iteration 200 loss 1.1734970808029175
 93 Epoch 21 Training loss 1.3634166858854262
 94 Epoch 22 iteration 0 loss 1.1587318181991577
 95 Epoch 22 iteration 100 loss 0.7660608291625977
 96 Epoch 22 iteration 200 loss 1.1832681894302368
 97 Epoch 22 Training loss 1.3601878647219552
 98 Epoch 23 iteration 0 loss 1.123557209968567
 99 Epoch 23 iteration 100 loss 0.7884796857833862
100 Epoch 23 iteration 200 loss 1.131569266319275
101 Epoch 23 Training loss 1.3543664767232568
102 Epoch 24 iteration 0 loss 1.1566004753112793
103 Epoch 24 iteration 100 loss 0.7894638180732727
104 Epoch 24 iteration 200 loss 1.1293442249298096
105 Epoch 24 Training loss 1.3513351050205646
106 Epoch 25 iteration 0 loss 1.1237646341323853
107 Epoch 25 iteration 100 loss 0.7442751526832581
108 Epoch 25 iteration 200 loss 1.1396199464797974
109 Epoch 25 Training loss 1.3436930389138495
110 Evaluation loss 3.331576243054354
111 Epoch 26 iteration 0 loss 1.1391510963439941
112 Epoch 26 iteration 100 loss 0.7658866047859192
113 Epoch 26 iteration 200 loss 1.130005121231079
114 Epoch 26 Training loss 1.3387227258896204
115 Epoch 27 iteration 0 loss 1.086417555809021
116 Epoch 27 iteration 100 loss 0.7512990236282349
117 Epoch 27 iteration 200 loss 1.1055928468704224
118 Epoch 27 Training loss 1.3332018813254016
119 Epoch 28 iteration 0 loss 1.1308163404464722
120 Epoch 28 iteration 100 loss 0.7653459310531616
121 Epoch 28 iteration 200 loss 1.1437530517578125
122 Epoch 28 Training loss 1.3318316266073582
123 Epoch 29 iteration 0 loss 1.1284910440444946
124 Epoch 29 iteration 100 loss 0.7385256886482239
125 Epoch 29 iteration 200 loss 1.076254963874817
126 Epoch 29 Training loss 1.327704983812702
127 Epoch 30 iteration 0 loss 1.1279666423797607
128 Epoch 30 iteration 100 loss 0.7510428428649902
129 Epoch 30 iteration 200 loss 1.10474693775177
130 Epoch 30 Training loss 1.3247037412152105
131 Evaluation loss 3.345638094832775
132 Epoch 31 iteration 0 loss 1.1144018173217773
133 Epoch 31 iteration 100 loss 0.7183322906494141
134 Epoch 31 iteration 200 loss 1.1657849550247192
135 Epoch 31 Training loss 1.3181022928511037
136 Epoch 32 iteration 0 loss 1.1624877452850342
137 Epoch 32 iteration 100 loss 0.6971022486686707
138 Epoch 32 iteration 200 loss 1.1033793687820435
139 Epoch 32 Training loss 1.313637083400949
140 Epoch 33 iteration 0 loss 1.0961930751800537
141 Epoch 33 iteration 100 loss 0.7509954571723938
142 Epoch 33 iteration 200 loss 1.0901885032653809
143 Epoch 33 Training loss 1.3105013603183797
144 Epoch 34 iteration 0 loss 1.0936028957366943
145 Epoch 34 iteration 100 loss 0.7300226092338562
146 Epoch 34 iteration 200 loss 1.094140648841858
147 Epoch 34 Training loss 1.3085180236466905
148 Epoch 35 iteration 0 loss 1.1358038187026978
149 Epoch 35 iteration 100 loss 0.6928472518920898
150 Epoch 35 iteration 200 loss 1.1031907796859741
151 Epoch 35 Training loss 1.2983291715098229
152 Evaluation loss 3.3654267819449917
153 Epoch 36 iteration 0 loss 1.0817443132400513
154 Epoch 36 iteration 100 loss 0.7034777998924255
155 Epoch 36 iteration 200 loss 1.1244701147079468
156 Epoch 36 Training loss 1.294685624884497
157 Epoch 37 iteration 0 loss 1.0067986249923706
158 Epoch 37 iteration 100 loss 0.6711763739585876
159 Epoch 37 iteration 200 loss 1.0877138376235962
160 Epoch 37 Training loss 1.2908666166705178
161 Epoch 38 iteration 0 loss 1.0796058177947998
162 Epoch 38 iteration 100 loss 0.6984289288520813
163 Epoch 38 iteration 200 loss 1.0992212295532227
164 Epoch 38 Training loss 1.289693898836594
165 Epoch 39 iteration 0 loss 1.1193760633468628
166 Epoch 39 iteration 100 loss 0.7441080212593079
167 Epoch 39 iteration 200 loss 1.0557031631469727
168 Epoch 39 Training loss 1.287817969907393
169 Epoch 40 iteration 0 loss 1.0878312587738037
170 Epoch 40 iteration 100 loss 0.7390894889831543
171 Epoch 40 iteration 200 loss 1.0931909084320068
172 Epoch 40 Training loss 1.281862247889987
173 Evaluation loss 3.3775776429435034
174 Epoch 41 iteration 0 loss 1.1135987043380737
175 Epoch 41 iteration 100 loss 0.6786257028579712
176 Epoch 41 iteration 200 loss 1.056801676750183
177 Epoch 41 Training loss 1.2759448500226234
178 Epoch 42 iteration 0 loss 1.0649049282073975
179 Epoch 42 iteration 100 loss 0.6774815320968628
180 Epoch 42 iteration 200 loss 1.0807018280029297
181 Epoch 42 Training loss 1.2713608683004023
182 Epoch 43 iteration 0 loss 1.0711919069290161
183 Epoch 43 iteration 100 loss 0.6655244827270508
184 Epoch 43 iteration 200 loss 1.0616692304611206
185 Epoch 43 Training loss 1.2693709204800718
186 Epoch 44 iteration 0 loss 1.0423146486282349
187 Epoch 44 iteration 100 loss 0.7055337429046631
188 Epoch 44 iteration 200 loss 1.0746649503707886
189 Epoch 44 Training loss 1.2649514066760854
190 Epoch 45 iteration 0 loss 1.0937353372573853
191 Epoch 45 iteration 100 loss 0.6939021348953247
192 Epoch 45 iteration 200 loss 1.1060905456542969
193 Epoch 45 Training loss 1.2591645727085945
194 Evaluation loss 3.393269126438251
195 Epoch 46 iteration 0 loss 1.1005926132202148
196 Epoch 46 iteration 100 loss 0.6948174238204956
197 Epoch 46 iteration 200 loss 1.0675958395004272
198 Epoch 46 Training loss 1.2555609077507983
199 Epoch 47 iteration 0 loss 1.0566778182983398
200 Epoch 47 iteration 100 loss 0.6904436349868774
201 Epoch 47 iteration 200 loss 1.0723766088485718
202 Epoch 47 Training loss 1.2552127211907091
203 Epoch 48 iteration 0 loss 1.0497757196426392
204 Epoch 48 iteration 100 loss 0.6351101398468018
205 Epoch 48 iteration 200 loss 1.0661102533340454
206 Epoch 48 Training loss 1.2479233140629313
207 Epoch 49 iteration 0 loss 1.0470858812332153
208 Epoch 49 iteration 100 loss 0.6707669496536255
209 Epoch 49 iteration 200 loss 1.063056230545044
210 Epoch 49 Training loss 1.2453716254928995
211 Epoch 50 iteration 0 loss 1.0854698419570923
212 Epoch 50 iteration 100 loss 0.6165581345558167
213 Epoch 50 iteration 200 loss 1.0804699659347534
214 Epoch 50 Training loss 1.243395479327141
215 Evaluation loss 3.4102849531750494
216 Epoch 51 iteration 0 loss 1.0279858112335205
217 Epoch 51 iteration 100 loss 0.6448107957839966
218 Epoch 51 iteration 200 loss 1.0390673875808716
219 Epoch 51 Training loss 1.2358123475125082
220 Epoch 52 iteration 0 loss 1.0429105758666992
221 Epoch 52 iteration 100 loss 0.7124451994895935
222 Epoch 52 iteration 200 loss 1.061672329902649
223 Epoch 52 Training loss 1.233326693902576
224 Epoch 53 iteration 0 loss 1.0357102155685425
225 Epoch 53 iteration 100 loss 0.6381393074989319
226 Epoch 53 iteration 200 loss 1.0036036968231201
227 Epoch 53 Training loss 1.2297246983027847
228 Epoch 54 iteration 0 loss 1.0590764284133911
229 Epoch 54 iteration 100 loss 0.6603832840919495
230 Epoch 54 iteration 200 loss 1.0215944051742554
231 Epoch 54 Training loss 1.227340017322883
232 Epoch 55 iteration 0 loss 1.0460106134414673
233 Epoch 55 iteration 100 loss 0.67122882604599
234 Epoch 55 iteration 200 loss 1.0344772338867188
235 Epoch 55 Training loss 1.2244369935263697
236 Evaluation loss 3.4193579365732183
237 Epoch 56 iteration 0 loss 1.032409429550171
238 Epoch 56 iteration 100 loss 0.6183319091796875
239 Epoch 56 iteration 200 loss 0.9782896637916565
240 Epoch 56 Training loss 1.2178285635372972
241 Epoch 57 iteration 0 loss 1.0382548570632935
242 Epoch 57 iteration 100 loss 0.6902874708175659
243 Epoch 57 iteration 200 loss 1.016508936882019
244 Epoch 57 Training loss 1.214647328633978
245 Epoch 58 iteration 0 loss 1.0595533847808838
246 Epoch 58 iteration 100 loss 0.6885846853256226
247 Epoch 58 iteration 200 loss 1.0221766233444214
248 Epoch 58 Training loss 1.2100419675457097
249 Epoch 59 iteration 0 loss 1.014621615409851
250 Epoch 59 iteration 100 loss 0.602800190448761
251 Epoch 59 iteration 200 loss 1.037442684173584
252 Epoch 59 Training loss 1.2131489746903632
253 Epoch 60 iteration 0 loss 1.0217640399932861
254 Epoch 60 iteration 100 loss 0.6246439814567566
255 Epoch 60 iteration 200 loss 1.00297212600708
256 Epoch 60 Training loss 1.204652290841725
257 Evaluation loss 3.4292290158399075
258 Epoch 61 iteration 0 loss 0.9992070198059082
259 Epoch 61 iteration 100 loss 0.645142138004303
260 Epoch 61 iteration 200 loss 0.9961024522781372
261 Epoch 61 Training loss 1.2066241352823563
262 Epoch 62 iteration 0 loss 0.9980950951576233
263 Epoch 62 iteration 100 loss 0.6504135131835938
264 Epoch 62 iteration 200 loss 1.000308632850647
265 Epoch 62 Training loss 1.1984729866171178
266 Epoch 63 iteration 0 loss 0.9869410395622253
267 Epoch 63 iteration 100 loss 0.6618863344192505
268 Epoch 63 iteration 200 loss 0.981200635433197
269 Epoch 63 Training loss 1.1967191445048035
270 Epoch 64 iteration 0 loss 0.9695953130722046
271 Epoch 64 iteration 100 loss 0.6359274387359619
272 Epoch 64 iteration 200 loss 0.9904515743255615
273 Epoch 64 Training loss 1.194521029171779
274 Epoch 65 iteration 0 loss 0.9505796432495117
275 Epoch 65 iteration 100 loss 0.6068794131278992
276 Epoch 65 iteration 200 loss 0.980348527431488
277 Epoch 65 Training loss 1.189270519480765
278 Evaluation loss 3.454153442993674
279 Epoch 66 iteration 0 loss 1.0304545164108276
280 Epoch 66 iteration 100 loss 0.6792566776275635
281 Epoch 66 iteration 200 loss 0.9789241552352905
282 Epoch 66 Training loss 1.1869953296382767
283 Epoch 67 iteration 0 loss 0.957666277885437
284 Epoch 67 iteration 100 loss 0.584879994392395
285 Epoch 67 iteration 200 loss 1.0174148082733154
286 Epoch 67 Training loss 1.184179561090835
287 Epoch 68 iteration 0 loss 1.043166995048523
288 Epoch 68 iteration 100 loss 0.6168758869171143
289 Epoch 68 iteration 200 loss 1.0030053853988647
290 Epoch 68 Training loss 1.1824355462851552
291 Epoch 69 iteration 0 loss 1.0165300369262695
292 Epoch 69 iteration 100 loss 0.6542645692825317
293 Epoch 69 iteration 200 loss 1.0191236734390259
294 Epoch 69 Training loss 1.176021675397731
295 Epoch 70 iteration 0 loss 0.9590736031532288
296 Epoch 70 iteration 100 loss 0.6157773733139038
297 Epoch 70 iteration 200 loss 1.0451829433441162
298 Epoch 70 Training loss 1.1732503092442255
299 Evaluation loss 3.4715566423642277
300 Epoch 71 iteration 0 loss 0.971733570098877
301 Epoch 71 iteration 100 loss 0.5589802265167236
302 Epoch 71 iteration 200 loss 1.0018212795257568
303 Epoch 71 Training loss 1.1694891346833023
304 Epoch 72 iteration 0 loss 1.0042874813079834
305 Epoch 72 iteration 100 loss 0.6543828248977661
306 Epoch 72 iteration 200 loss 0.968835175037384
307 Epoch 72 Training loss 1.1667191714442264
308 Epoch 73 iteration 0 loss 0.9512341022491455
309 Epoch 73 iteration 100 loss 0.5809782147407532
310 Epoch 73 iteration 200 loss 0.9460022449493408
311 Epoch 73 Training loss 1.165780372824424
312 Epoch 74 iteration 0 loss 0.9838390946388245
313 Epoch 74 iteration 100 loss 0.6115572452545166
314 Epoch 74 iteration 200 loss 0.9821975827217102
315 Epoch 74 Training loss 1.1619031632185661
316 Epoch 75 iteration 0 loss 0.9615085124969482
317 Epoch 75 iteration 100 loss 0.5715279579162598
318 Epoch 75 iteration 200 loss 0.9673617482185364
319 Epoch 75 Training loss 1.1592393025041507
320 Evaluation loss 3.4810480487503015
321 Epoch 76 iteration 0 loss 0.9920525550842285
322 Epoch 76 iteration 100 loss 0.6243174076080322
323 Epoch 76 iteration 200 loss 0.9598985910415649
324 Epoch 76 Training loss 1.1506768550866349
325 Epoch 77 iteration 0 loss 0.9717826843261719
326 Epoch 77 iteration 100 loss 0.5903583765029907
327 Epoch 77 iteration 200 loss 0.9472079873085022
328 Epoch 77 Training loss 1.151601228059984
329 Epoch 78 iteration 0 loss 0.9331899881362915
330 Epoch 78 iteration 100 loss 0.6189018487930298
331 Epoch 78 iteration 200 loss 0.9951513409614563
332 Epoch 78 Training loss 1.1474281610158772
333 Epoch 79 iteration 0 loss 0.9012037515640259
334 Epoch 79 iteration 100 loss 0.5837778449058533
335 Epoch 79 iteration 200 loss 0.9066386818885803
336 Epoch 79 Training loss 1.142656700489289
337 Epoch 80 iteration 0 loss 0.9931736588478088
338 Epoch 80 iteration 100 loss 0.5927265882492065
339 Epoch 80 iteration 200 loss 0.938447892665863
340 Epoch 80 Training loss 1.1434640075302192
341 Evaluation loss 3.491394963812294
342 Epoch 81 iteration 0 loss 0.9227023720741272
343 Epoch 81 iteration 100 loss 0.5467157363891602
344 Epoch 81 iteration 200 loss 0.9126712083816528
345 Epoch 81 Training loss 1.1427882346320761
346 Epoch 82 iteration 0 loss 0.9733406901359558
347 Epoch 82 iteration 100 loss 0.564643144607544
348 Epoch 82 iteration 200 loss 0.9918593764305115
349 Epoch 82 Training loss 1.1362837826371996
350 Epoch 83 iteration 0 loss 0.9489978551864624
351 Epoch 83 iteration 100 loss 0.5791521668434143
352 Epoch 83 iteration 200 loss 0.9270768165588379
353 Epoch 83 Training loss 1.136011173156159
354 Epoch 84 iteration 0 loss 0.9410436749458313
355 Epoch 84 iteration 100 loss 0.5409624576568604
356 Epoch 84 iteration 200 loss 0.8918321132659912
357 Epoch 84 Training loss 1.128506034776273
358 Epoch 85 iteration 0 loss 0.9554007053375244
359 Epoch 85 iteration 100 loss 0.571331799030304
360 Epoch 85 iteration 200 loss 0.9672144055366516
361 Epoch 85 Training loss 1.1277133646535586
362 Evaluation loss 3.5075310684848158
363 Epoch 86 iteration 0 loss 0.9104467630386353
364 Epoch 86 iteration 100 loss 0.5656437277793884
365 Epoch 86 iteration 200 loss 0.9324206113815308
366 Epoch 86 Training loss 1.126375005188875
367 Epoch 87 iteration 0 loss 0.9339620471000671
368 Epoch 87 iteration 100 loss 0.5636867880821228
369 Epoch 87 iteration 200 loss 0.8825109601020813
370 Epoch 87 Training loss 1.1222316938253494
371 Epoch 88 iteration 0 loss 0.904504120349884
372 Epoch 88 iteration 100 loss 0.5706378221511841
373 Epoch 88 iteration 200 loss 0.9415532350540161
374 Epoch 88 Training loss 1.1215731092872845
375 Epoch 89 iteration 0 loss 0.9489354491233826
376 Epoch 89 iteration 100 loss 0.6389216184616089
377 Epoch 89 iteration 200 loss 0.8783397078514099
378 Epoch 89 Training loss 1.1199876689692458
379 Epoch 90 iteration 0 loss 0.909376323223114
380 Epoch 90 iteration 100 loss 0.6190019249916077
381 Epoch 90 iteration 200 loss 0.9191233515739441
382 Epoch 90 Training loss 1.1181392741798546
383 Evaluation loss 3.508678977926201
384 Epoch 91 iteration 0 loss 0.9080389738082886
385 Epoch 91 iteration 100 loss 0.5580074191093445
386 Epoch 91 iteration 200 loss 0.9494779706001282
387 Epoch 91 Training loss 1.1147855064570311
388 Epoch 92 iteration 0 loss 0.900802731513977
389 Epoch 92 iteration 100 loss 0.573580801486969
390 Epoch 92 iteration 200 loss 0.9199456572532654
391 Epoch 92 Training loss 1.1107969536786537
392 Epoch 93 iteration 0 loss 0.9345868229866028
393 Epoch 93 iteration 100 loss 0.5590959787368774
394 Epoch 93 iteration 200 loss 0.90354984998703
395 Epoch 93 Training loss 1.105984925602608
396 Epoch 94 iteration 0 loss 0.9008861780166626
397 Epoch 94 iteration 100 loss 0.5503742098808289
398 Epoch 94 iteration 200 loss 0.8791723251342773
399 Epoch 94 Training loss 1.1053885063813342
400 Epoch 95 iteration 0 loss 0.899246096611023
401 Epoch 95 iteration 100 loss 0.6236768364906311
402 Epoch 95 iteration 200 loss 0.8661567568778992
403 Epoch 95 Training loss 1.0993307278503
404 Evaluation loss 3.5332032706941585
405 Epoch 96 iteration 0 loss 0.8837733864784241
406 Epoch 96 iteration 100 loss 0.5473974943161011
407 Epoch 96 iteration 200 loss 0.9025910496711731
408 Epoch 96 Training loss 1.0998253373283113
409 Epoch 97 iteration 0 loss 0.922965407371521
410 Epoch 97 iteration 100 loss 0.5556969046592712
411 Epoch 97 iteration 200 loss 0.9027858972549438
412 Epoch 97 Training loss 1.096842199480861
413 Epoch 98 iteration 0 loss 0.8947715759277344
414 Epoch 98 iteration 100 loss 0.5312948822975159
415 Epoch 98 iteration 200 loss 0.9379984736442566
416 Epoch 98 Training loss 1.0949072217540066
417 Epoch 99 iteration 0 loss 0.8829227685928345
418 Epoch 99 iteration 100 loss 0.5451477766036987
419 Epoch 99 iteration 200 loss 0.8783729672431946
420 Epoch 99 Training loss 1.092216717956385
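
Over the epochs shown here the training loss keeps falling (about 1.23 at epoch 54 down to 1.09 at epoch 99) while the evaluation loss keeps rising (about 3.42 at epoch 55 up to 3.53 at epoch 95), so the model is overfitting the small corpus well before epoch 100. A simple remedy would be to stop on the dev loss instead of training for a fixed number of epochs. The sketch below only illustrates that idea and is not part of the original code: train_one_epoch and evaluate_loss are hypothetical stand-ins for the training and evaluation loops defined earlier (assumed to return the average loss), and patience is a made-up hyper-parameter.

import torch

best_eval_loss = float("inf")
patience, bad_epochs = 5, 0                             # made-up patience hyper-parameter

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_data)     # hypothetical: one pass over the training set, returns avg loss
    eval_loss = evaluate_loss(model, dev_data)          # hypothetical: average loss on the dev set
    print("Epoch", epoch, "Training loss", train_loss, "Evaluation loss", eval_loss)

    if eval_loss < best_eval_loss:                       # dev loss improved: remember this checkpoint
        best_eval_loss = eval_loss
        bad_epochs = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:                                                # dev loss got worse
        bad_epochs += 1
        if bad_epochs >= patience:
            print("Early stopping at epoch", epoch)
            break

Reloading the checkpoint with the lowest dev loss before translating would then replace the final epoch-99 weights.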

7.5 Calling the translate_dev function defined above

1 for i in range(100,120):
2     translate_dev(i)
3     print()

The execution results are as follows:

 1 BOS you have nice skin . EOS
 2 BOS 你 的 皮 膚 真 好 。 EOS
 3 你有足球的食物都好了
 4 
 5 BOS you re UNK correct . EOS
 6 BOS 你 部 分 正 确 。 EOS
 7 你是个好厨师。
 8 
 9 BOS everyone admired his courage . EOS
10 BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
11 他們每個人都很好奇。
12 
13 BOS what time is it ? EOS
14 BOS 几 点 了 ? EOS
15 它是什麼?
16 
17 BOS i m free tonight . EOS
18 BOS 我 今 晚 有 空 。 EOS
19 我今晚沒有空。
20 
21 BOS here is your book . EOS
22 BOS 這 是 你 的 書 。 EOS
23 你這附書是讀書。
24 
25 BOS they are at lunch . EOS
26 BOS 他 们 在 吃 午 饭 。 EOS
27 他们在吃米饭。
28 
29 BOS this chair is UNK . EOS
30 BOS 這 把 椅 子 很 UNK 。 EOS
31 這是真的最好的。
32 
33 BOS it s pretty heavy . EOS
34 BOS 它 真 重 。 EOS
35 它很有。
36 
37 BOS many attended his funeral . EOS
38 BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
39 仔细把他的罪恶着。
40 
41 BOS training will be provided . EOS
42 BOS 会 有 训 练 。 EOS
43 克林變得餓了。
44 
45 BOS someone is watching you . EOS
46 BOS 有 人 在 看 著 你 。 EOS
47 有人在信封信。
48 
49 BOS i slapped his face . EOS
50 BOS 我 摑 了 他 的 臉 。 EOS
51 我是他的兄弟。
52 
53 BOS i like UNK music . EOS
54 BOS 我 喜 歡 流 行 音 樂 。 EOS
55 我喜歡音樂。
56 
57 BOS tom had no children . EOS
58 BOS T o m 沒 有 孩 子 。 EOS
59 Tom沒有太累。
60 
61 BOS please lock the door . EOS
62 BOS 請 把 門 鎖 上 。 EOS
63 請關門。
64 
65 BOS tom has calmed down . EOS
66 BOS 汤 姆 冷 静 下 来 了 。 EOS
67 汤姆向伤极了。
68 
69 BOS please speak more loudly . EOS
70 BOS 請 說 大 聲 一 點 兒 。 EOS
71 請講話。
72 
73 BOS keep next sunday free . EOS
74 BOS 把 下 周 日 空 出 来 。 EOS
75 下周举一直流出席。
76 
77 BOS i made a mistake . EOS
78 BOS 我 犯 了 一 個 錯 。 EOS
79 我有一个梦意。

The translation results are still mediocre.
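
To put a number on this impression instead of judging by eye, a character-level BLEU score can be computed with nltk, which is already imported at the top of this post. The snippet below is only a toy illustration of the mechanics, with the two sentences copied by hand from the "i m free tonight ." example above; in a real evaluation the hypotheses would come from the decoder and the references from dev_cn.

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

# One list of references per sentence; tokens are single characters,
# matching the character-level Chinese vocabulary used in this post.
references = [[list("我今晚有空。")]]      # gold translation of "i m free tonight ."
hypotheses = [list("我今晚沒有空。")]      # the model output shown above
smooth = SmoothingFunction().method1      # smoothing avoids zero n-gram counts on short sentences
print(corpus_bleu(references, hypotheses, smoothing_function=smooth))

Averaged over the whole dev set this gives a single score that makes it easier to compare the version without attention against the version with attention.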

 


Original post: https://www.cnblogs.com/cxq1126/p/13565961.html
