Because the corpus is small and training time is short, the model does not perform well; the process is demonstrated below.
Corpus link:
Data format, as shown in the figure (English first, then whitespace, then traditional Chinese):
The following code was run on Google Colab.
Imports:
import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
nltk.download('punkt')
def load_data(in_file):
    cn = []
    en = []
    num_examples = 0
    with open(in_file, 'r') as f:
        for line in f:
            line = line.strip().split("\t")

            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])   # Chinese is split character by character
    return en, cn

train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)
Check the returned data:
print(dev_en[:2])
print(dev_cn[:2])
[['BOS', 'she', 'put', 'the', 'magazine', 'on', 'the', 'table', '.', 'EOS'], ['BOS', 'hey', ',', 'what', 'are', 'you', 'doing', 'here', '?', 'EOS']]
[['BOS', '她', '把', '雜', '誌', '放', '在', '桌', '上', '。', 'EOS'], ['BOS', '嘿', ',', '你', '在', '這', '做', '什', '麼', '?', 'EOS']]
UNK_IDX = 0
PAD_IDX = 1

def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words   # total_words is the vocabulary size, at most 50000 + 2 = 50002

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v: k for k, v in en_dict.items()}   # English: index -> word
inv_cn_dict = {v: k for k, v in cn_dict.items()}   # Chinese: index -> character
The purpose of sort_by_len=True is to make the sentences within a batch roughly the same length, so the sentences are sorted by length.
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    length = len(en_sentences)
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

    # sort sentences by word lengths
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # sort the Chinese and English sentences in the same order
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)
Check the returned data:
print(train_cn[2])
print([inv_cn_dict[i] for i in train_cn[2]])
print([inv_en_dict[i] for i in train_en[2]])
[2, 982, 2028, 8, 4, 3]
['BOS', '祝', '贺', '你', '。', 'EOS']
['BOS', 'congratulations', '!', 'EOS']
def get_minibatches(n, minibatch_size, shuffle=True):   # n is the number of sentences
    idx_list = np.arange(0, n, minibatch_size)   # start index of each minibatch: 0, minibatch_size, 2*minibatch_size, ...
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches
Check what the function above does:
get_minibatches(100, 15)

[array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]),
 array([75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
 array([45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
 array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]),
 array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])]
def prepare_data(seqs):   # seqs is one minibatch: batch_size sentences of word indices (a nested list); here batch_size = 64
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)   # length of the longest sentence in this batch

    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths

def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_ex = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex   # each tuple holds: English indices, English lengths, Chinese indices, Chinese lengths for batch_size sentences

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
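As a quick check of the padding logic (a toy input made up here, not from the original post), prepare_data right-pads every sentence in the minibatch with zeros up to the longest length and also returns the true lengths:

seqs = [[2, 15, 9, 3], [2, 7, 3]]   # two encoded toy sentences
x, x_lengths = prepare_data(seqs)
print(x)            # [[ 2 15  9  3]
                    #  [ 2  7  3  0]]   <- shorter sentence padded with 0
print(x_lengths)    # [4 3]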
# masked cross entropy loss
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):   # positions where mask == 0 are ignored
        # input: [batch_size, seq_len, vocab_size] -> [batch_size*seq_len, vocab_size]
        input = input.contiguous().view(-1, input.size(2))
        # target: [batch_size, seq_len] -> [batch_size*seq_len, 1]
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        output = -input.gather(1, target) * mask   # negative log-likelihood of the gold token, zeroed at padded positions
        output = torch.sum(output) / torch.sum(mask)

        return output
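A tiny sanity check of the masked loss (toy tensors, shapes as in the forward above); only the unmasked positions contribute to the average:

logp = F.log_softmax(torch.randn(2, 3, 5), dim=-1)    # [batch=2, seq_len=3, vocab=5] log-probabilities
target = torch.tensor([[1, 2, 3], [4, 0, 0]])         # gold indices; the second row has only one real token
mask = torch.tensor([[1., 1., 1.], [1., 0., 0.]])     # 0 marks padded positions
crit = LanguageModelCriterion()
print(crit(logp, target, mask))                       # average negative log-likelihood over the 4 unmasked tokens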
The Encoder's job is to pass the input text through an embedding layer and a GRU layer, turning it into hidden states that later serve as the context vectors.
For an explanation of nn.utils.rnn.pack_padded_sequence and nn.utils.rnn.pad_packed_sequence, see: http://www.mamicode.com/info-detail-2493083.html
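Roughly, pack_padded_sequence removes the padded time steps so the GRU only runs over real tokens (its final hidden state is then the state at each sentence's true last token), and pad_packed_sequence pads the outputs back into a regular tensor. A minimal standalone illustration with made-up shapes:

emb = torch.randn(2, 4, 8)                  # batch=2, padded length=4, feature=8
lengths = torch.tensor([4, 2])              # true lengths, already in descending order
packed = nn.utils.rnn.pack_padded_sequence(emb, lengths, batch_first=True)
rnn = nn.GRU(8, 8, batch_first=True)
packed_out, hid = rnn(packed)
out, out_lens = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, hid.shape, out_lens)       # torch.Size([2, 4, 8]) torch.Size([1, 2, 8]) tensor([4, 2])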
class PlainEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):   # assume embedding_size == hidden_size
        super(PlainEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):   # lengths are needed so the last real hidden state can be taken as the context vector
        sorted_len, sorted_idx = lengths.sort(0, descending=True)   # sort the sequences in the batch by length, descending
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        # Sentences are padded to a common length (the real sentences are shorter). pack_padded_sequence lets the RNN
        # stop at each sentence's true last token, so the final state corresponds to the real sentence end.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)   # back to padded length

        _, original_idx = sorted_idx.sort(0, descending=False)   # restore the original order
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        return out, hid[[-1]]   # hid[[-1]] keeps the last layer's final state with shape [1, batch_size, hidden_size]
The Decoder decides the next output token based on what has been translated so far and the context vectors.
class PlainDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):   # largely the same as PlainEncoder.forward, except the initial hidden state is passed in rather than zero
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))   # [batch_size, y_lengths, embed_size == hidden_size]

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()   # [batch_size, y_lengths, hidden_size]
        hid = hid[:, original_idx.long()].contiguous()            # [1, batch_size, hidden_size]

        output = F.log_softmax(self.fc(output_seq), -1)           # [batch_size, y_lengths, vocab_size]

        return output, hid
Build the Seq2Seq model that chains the encoder and decoder together (this first version has no attention).
class PlainSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(PlainSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y, y_lengths, hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):   # greedy decoding: feed the previous prediction back in, one step at a time
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid = self.decoder(y=y, y_lengths=torch.ones(batch_size).long().to(y.device), hid=hid)
            y = output.max(2)[1].view(batch_size, 1)   # take the most likely token as the next decoder input
            preds.append(y)

        return torch.cat(preds, 1), None
Define the model, the loss, and the optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100
encoder = PlainEncoder(vocab_size=en_total_words, hidden_size=hidden_size, dropout=dropout)
decoder = PlainDecoder(vocab_size=cn_total_words, hidden_size=hidden_size, dropout=dropout)
model = PlainSeq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()    # decoder input: everything except the last token
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()    # decoder target: everything except the first token
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss / total_num_words)
def train(model, data, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update the model
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss / total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)
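The line that builds mb_out_mask uses broadcasting to turn the target lengths into a 0/1 mask; a standalone check with toy lengths:

mb_y_len = torch.tensor([3, 1, 2])
mask = torch.arange(mb_y_len.max().item())[None, :] < mb_y_len[:, None]
print(mask.float())
# tensor([[1., 1., 1.],
#         [1., 0., 0.],
#         [1., 1., 0.]])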
Train for 100 epochs:
train(model, train_data, num_epochs=100)
Training results (the training loss keeps decreasing):
Epoch 0 iteration 0 loss 8.084440231323242
Epoch 0 iteration 100 loss 4.8448944091796875
Epoch 0 iteration 200 loss 4.879772663116455
Epoch 0 Training loss 5.477221919210141
Evaluation loss 4.821030395389826
Epoch 1 iteration 0 loss 4.69868278503418
Epoch 1 iteration 100 loss 4.085171699523926
Epoch 1 iteration 200 loss 4.312857151031494
Epoch 1 Training loss 4.579521701350524
Epoch 2 iteration 0 loss 4.193971633911133
Epoch 2 iteration 100 loss 3.678673267364502
Epoch 2 iteration 200 loss 4.019515514373779
Epoch 2 Training loss 4.186071368925457
... (epochs 3-97 omitted: the training loss keeps falling, while the evaluation loss bottoms out around 3.13 near epochs 40-45 and then rises slowly) ...
Epoch 98 iteration 0 loss 1.308292269706726
Epoch 98 iteration 100 loss 0.9079441428184509
Epoch 98 iteration 200 loss 1.2940715551376343
Epoch 98 Training loss 1.4928753682563118
Epoch 99 iteration 0 loss 1.276250958442688
Epoch 99 iteration 100 loss 0.890657901763916
Epoch 99 iteration 200 loss 1.3286609649658203
Epoch 99 Training loss 1.4852960116094391
def translate_dev(i):
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])   # original English sentence
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])   # reference Chinese sentence
    print("".join(cn_sent))

    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

    translation, attn = model.translate(mb_x, mb_x_len, bos)
    translation = [inv_cn_dict[i] for i in translation.data.cpu().numpy().reshape(-1)]
    trans = []
    for word in translation:
        if word != "EOS":
            trans.append(word)
        else:
            break
    print("".join(trans))   # the model's translation

for i in range(100, 120):
    translate_dev(i)
    print()
The output is shown below (with so little data and such short training, the translations are not very good):
BOS you have nice skin . EOS
BOS 你 的 皮 膚 真 好 。 EOS
你有很多好。

BOS you 're UNK correct . EOS
BOS 你 部 分 正 确 。 EOS
你是个好人。

BOS everyone admired his courage . EOS
BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
他們的電話讓他們一個

BOS what time is it ? EOS
BOS 几 点 了 ? EOS
它还是什么?

BOS i 'm free tonight . EOS
BOS 我 今 晚 有 空 。 EOS
我今晚有空。

BOS here is your book . EOS
BOS 這 是 你 的 書 。 EOS
你的書桌是舊。

BOS they are at lunch . EOS
BOS 他 们 在 吃 午 饭 。 EOS
他们正在吃米饭。

BOS this chair is UNK . EOS
BOS 這 把 椅 子 很 UNK 。 EOS
這是真的最好的人。

BOS it 's pretty heavy . EOS
BOS 它 真 重 。 EOS
它是真的。

BOS many attended his funeral . EOS
BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
AI他的襪子。

BOS training will be provided . EOS
BOS 会 有 训 练 。 EOS
人们的货品造成的。

BOS someone is watching you . EOS
BOS 有 人 在 看 著 你 。 EOS
有人叫醒汤姆。

BOS i slapped his face . EOS
BOS 我 摑 了 他 的 臉 。 EOS
我有他的兄弟。

BOS i like UNK music . EOS
BOS 我 喜 歡 流 行 音 樂 。 EOS
我喜欢狗在家。

BOS tom had no children . EOS
BOS T o m 沒 有 孩 子 。 EOS
汤姆不需要做什么。

BOS please lock the door . EOS
BOS 請 把 門 鎖 上 。 EOS
請把門打開。

BOS tom has calmed down . EOS
BOS 汤 姆 冷 静 下 来 了 。 EOS
汤姆睡着了。

BOS please speak more loudly . EOS
BOS 請 說 大 聲 一 點 兒 。 EOS
請講慢一點。

BOS keep next sunday free . EOS
BOS 把 下 周 日 空 出 来 。 EOS
下午可以轉下。

BOS i made a mistake . EOS
BOS 我 犯 了 一 個 錯 。 EOS
我有些意生。
Now the version with attention. As before, the Encoder passes the input text through an embedding layer and a GRU (here bidirectional), producing hidden states that will serve as the context vectors.
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        hid = torch.cat([hid[-2], hid[-1]], dim=1)   # bidirectional, so concatenate the two directions' final states
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)  # project 2*enc_hidden_size down to dec_hidden_size: [1, batch_size, dec_hidden_size]

        return out, hid
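For a single-layer bidirectional GRU, hid has shape [2, batch_size, enc_hidden_size] (one final state per direction), so hid[-2] and hid[-1] above are the forward and backward final states that get concatenated. A quick shape check with made-up sizes:

rnn = nn.GRU(16, 32, batch_first=True, bidirectional=True)
x = torch.randn(4, 7, 16)                            # batch=4, seq_len=7, embed=16
out, hid = rnn(x)
print(out.shape)                                     # torch.Size([4, 7, 64])  -> 2 * enc_hidden_size per position
print(hid.shape)                                     # torch.Size([2, 4, 32])  -> [num_directions, batch, enc_hidden_size]
print(torch.cat([hid[-2], hid[-1]], dim=1).shape)    # torch.Size([4, 64])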
In the figure (omitted here), ht is the decoder output and hs is the context (the encoder outputs); the second of the scoring methods, the "general" score, is the one used to compute the attention scores.
The output is then computed from the context vectors and the current decoder hidden states.
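Written out (my summary of what the Attention module below computes; this follows Luong et al.'s "general" attention), with decoder state h_t and encoder outputs h_s:

score(h_t, h_s) = h_t^T W_a h_s
a_ts = softmax_s(score(h_t, h_s))      (padded source positions are masked out before the softmax)
c_t = sum_s a_ts * h_s
attended output = tanh(W_c [c_t ; h_t])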
class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()

        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size

        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)   # linear transform of the encoder outputs
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        # output: [batch_size, output_len, dec_hidden_size]
        # context: [batch_size, context_len, 2*enc_hidden_size]

        batch_size = output.size(0)
        output_len = output.size(1)
        context_len = context.size(1)

        # context_in: [batch_size, context_len, dec_hidden_size]
        context_in = self.linear_in(context.view(batch_size * context_len, -1)).view(batch_size, context_len, -1)

        # context_in.transpose(1, 2): [batch_size, dec_hidden_size, context_len]
        attn = torch.bmm(output, context_in.transpose(1, 2))   # attn: [batch_size, output_len, context_len]

        attn.data.masked_fill_(mask, -1e6)   # in-place masked_fill_; plain masked_fill returns a new tensor and would leave attn unmasked

        attn = F.softmax(attn, dim=2)   # attn: [batch_size, output_len, context_len]

        context = torch.bmm(attn, context)   # context: [batch_size, output_len, 2*enc_hidden_size]

        output = torch.cat((context, output), dim=2)   # output: [batch_size, output_len, 2*enc_hidden_size + dec_hidden_size]

        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)

        return output, attn
The Decoder decides the next output token based on what has been translated so far and the context vectors.
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)   # was `hidden_size`, which only worked via the global variable
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):   # mask of shape [batch_size, max_x_len, max_y_len]; True marks positions the attention must ignore
        device = x_len.device
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=x_len.device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=x_len.device)[None, :] < y_len[:, None]
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])   # on newer PyTorch, `1 - bool_tensor` is not allowed, so use ~ and &
        return mask

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))   # [batch_size, output_len, embed_size]

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        mask = self.create_mask(y_lengths, ctx_lengths)

        output, attn = self.attention(output_seq, ctx, mask)   # attend over the encoder outputs given the decoder states
        output = F.log_softmax(self.out(output), -1)

        return output, hid, attn
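The next cell instantiates a Seq2Seq class whose definition is not included in this post. A minimal reconstruction (my own sketch, mirroring PlainSeq2Seq while also passing the encoder outputs and source lengths to the attention Decoder above) could look like this:

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(ctx=encoder_out, ctx_lengths=x_lengths,
                                         y=y, y_lengths=y_lengths, hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=10):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(ctx=encoder_out, ctx_lengths=x_lengths,
                                             y=y, y_lengths=torch.ones(batch_size).long().to(y.device), hid=hid)
            y = output.max(2)[1].view(batch_size, 1)   # greedy: take the most likely character as the next input
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)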
dropout = 0.2
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words, embed_size=embed_size, enc_hidden_size=hidden_size, dec_hidden_size=hidden_size, dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words, embed_size=embed_size, enc_hidden_size=hidden_size, dec_hidden_size=hidden_size, dropout=dropout)
model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())

train(model, train_data, num_epochs=100)
Training results:
Epoch 0 iteration 0 loss 1.3026132583618164
Epoch 0 iteration 100 loss 0.8847191333770752
Epoch 0 iteration 200 loss 1.285671353340149
Epoch 0 Training loss 1.4803871257447
Evaluation loss 3.260634059314127
Epoch 1 iteration 0 loss 1.2434465885162354
Epoch 1 iteration 100 loss 0.8472312092781067
Epoch 1 iteration 200 loss 1.282746434211731
Epoch 1 Training loss 1.4731217075462495
Epoch 2 iteration 0 loss 1.2593930959701538
Epoch 2 iteration 100 loss 0.8484001159667969
Epoch 2 iteration 200 loss 1.2862968444824219
Epoch 2 Training loss 1.466728115795638
... (epochs 3-97 omitted: the training loss falls from about 1.47 to about 1.10, while the evaluation loss rises slowly from about 3.26 to about 3.53) ...
Epoch 98 iteration 0 loss 0.8947715759277344
Epoch 98 iteration 100 loss 0.5312948822975159
Epoch 98 iteration 200 loss 0.9379984736442566
Epoch 98 Training loss 1.0949072217540066
Epoch 99 iteration 0 loss 0.8829227685928345
Epoch 99 iteration 100 loss 0.5451477766036987
Epoch 99 iteration 200 loss 0.8783729672431946
Epoch 99 Training loss 1.092216717956385
for i in range(100, 120):
    translate_dev(i)
    print()
The output is shown below:
BOS you have nice skin . EOS
BOS 你 的 皮 膚 真 好 。 EOS
你有足球的食物都好了

BOS you 're UNK correct . EOS
BOS 你 部 分 正 确 。 EOS
你是个好厨师。

BOS everyone admired his courage . EOS
BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
他們每個人都很好奇。

BOS what time is it ? EOS
BOS 几 点 了 ? EOS
它是什麼?

BOS i 'm free tonight . EOS
BOS 我 今 晚 有 空 。 EOS
我今晚沒有空。

BOS here is your book . EOS
BOS 這 是 你 的 書 。 EOS
你這附書是讀書。

BOS they are at lunch . EOS
BOS 他 们 在 吃 午 饭 。 EOS
他们在吃米饭。

BOS this chair is UNK . EOS
BOS 這 把 椅 子 很 UNK 。 EOS
這是真的最好的。

BOS it 's pretty heavy . EOS
BOS 它 真 重 。 EOS
它很有。

BOS many attended his funeral . EOS
BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
仔细把他的罪恶着。

BOS training will be provided . EOS
BOS 会 有 训 练 。 EOS
克林變得餓了。

BOS someone is watching you . EOS
BOS 有 人 在 看 著 你 。 EOS
有人在信封信。

BOS i slapped his face . EOS
BOS 我 摑 了 他 的 臉 。 EOS
我是他的兄弟。

BOS i like UNK music . EOS
BOS 我 喜 歡 流 行 音 樂 。 EOS
我喜歡音樂。

BOS tom had no children . EOS
BOS T o m 沒 有 孩 子 。 EOS
Tom沒有太累。

BOS please lock the door . EOS
BOS 請 把 門 鎖 上 。 EOS
請關門。

BOS tom has calmed down . EOS
BOS 汤 姆 冷 静 下 来 了 。 EOS
汤姆向伤极了。

BOS please speak more loudly . EOS
BOS 請 說 大 聲 一 點 兒 。 EOS
請講話。

BOS keep next sunday free . EOS
BOS 把 下 周 日 空 出 来 。 EOS
下周举一直流出席。

BOS i made a mistake . EOS
BOS 我 犯 了 一 個 錯 。 EOS
我有一个梦意。
The translations are still mediocre.
Pytorch seq2seq machine translation model (two versions: without attention and with attention)
Original post: https://www.cnblogs.com/cxq1126/p/13565961.html