标签:ret set sts range mbed embed rect ignore 计算
下面的写法是在解码循环内部加入 teacher forcing 机制(每个时间步随机决定是否使用真实目标作为下一步输入)。当目标(target)确定时,可以这样加;
如果目标不确定,则需要在循环外加。
decoder.py 中的修改
""" 实现解码器 """ import torch.nn as nn import config import torch import torch.nn.functional as F import numpy as np import random class Decoder(nn.Module): def __init__(self): super(Decoder, self).__init__() self.embedding = nn.Embedding(num_embeddings=len(config.ns), embedding_dim=50, padding_idx=config.ns.PAD) # 需要的hidden_state形状:[1,batch_size,64] self.gru = nn.GRU(input_size=50, hidden_size=64, num_layers=1, bidirectional=False, batch_first=True, dropout=0) # 假如encoder的hidden_size=64,num_layer=1 encoder_hidden :[2,batch_sizee,64] self.fc = nn.Linear(64, len(config.ns)) def forward(self, encoder_hidden,target): # 第一个时间步的输入的hidden_state decoder_hidden = encoder_hidden # [1,batch_size,encoder_hidden_size] # 第一个时间步的输入的input batch_size = encoder_hidden.size(1) decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device) # [batch_size,1] # print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size] decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device) for t in range(config.max_len): decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden) decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值 value, index = decoder_output_t.max(dim=-1) if random.randint(0,100) >70: #teacher forcing机制 decoder_input = target[:,t].unsqueeze(-1) else: decoder_input = index.unsqueeze(-1) # [batch_size,1] # print("decoder_input:",decoder_input.size()) return decoder_outputs, decoder_hidden def forward_step(self, decoder_input, decoder_hidden): ‘‘‘ 计算一个时间步的结果 :param decoder_input: [batch_size,1] :param decoder_hidden: [batch_size,encoder_hidden_size] :return: ‘‘‘ decoder_input_embeded = self.embedding(decoder_input) # print("decoder_input_embeded:",decoder_input_embeded.size()) out, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden) # out :【batch_size,1,hidden_size】 out_squeezed = out.squeeze(dim=1) # 去掉为1的维度 out_fc = F.log_softmax(self.fc(out_squeezed), 
dim=-1) # [bathc_size,vocab_size] # out_fc.unsqueeze_(dim=1) #[bathc_size,1,vocab_size] # print("out_fc:",out_fc.size()) return out_fc, decoder_hidden def evaluate(self, encoder_hidden): # 第一个时间步的输入的hidden_state decoder_hidden = encoder_hidden # [1,batch_size,encoder_hidden_size] # 第一个时间步的输入的input batch_size = encoder_hidden.size(1) decoder_input = torch.LongTensor([[config.ns.SOS]] * batch_size).to(config.device) # [batch_size,1] # print("decoder_input:",decoder_input.size()) # 使用全为0的数组保存数据,[batch_size,max_len,vocab_size] decoder_outputs = torch.zeros([batch_size, config.max_len, len(config.ns)]).to(config.device) decoder_predict = [] # [[],[],[]] #123456 ,targe:123456EOS,predict:123456EOS123 for t in range(config.max_len): decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden) decoder_outputs[:, t, :] = decoder_output_t # 获取当前时间步的预测值 value, index = decoder_output_t.max(dim=-1) decoder_input = index.unsqueeze(-1) # [batch_size,1] # print("decoder_input:",decoder_input.size()) decoder_predict.append(index.cpu().detach().numpy()) # 返回预测值 decoder_predict = np.array(decoder_predict).transpose() # [batch_size,max_len] return decoder_outputs, decoder_predict
seq2seq.py
""" 完成seq2seq模型 """ import torch.nn as nn from encoder import Encoder from decoder import Decoder class Seq2Seq(nn.Module): def __init__(self): super(Seq2Seq, self).__init__() self.encoder = Encoder() self.decoder = Decoder() def forward(self, input, input_len,target): encoder_outputs, encoder_hidden = self.encoder(input, input_len) decoder_outputs, decoder_hidden = self.decoder(encoder_hidden,target) return decoder_outputs def evaluate(self, input, input_len): encoder_outputs, encoder_hidden = self.encoder(input, input_len) decoder_outputs, decoder_predict = self.decoder.evaluate(encoder_hidden) return decoder_outputs, decoder_predict
train.py
""" 进行模型的训练 """ import torch import torch.nn.functional as F from seq2seq import Seq2Seq from torch.optim import Adam from dataset import get_dataloader from tqdm import tqdm import config import numpy as np import pickle from matplotlib import pyplot as plt from eval import eval import os model = Seq2Seq().to(config.device) optimizer = Adam(model.parameters()) if os.path.exists("./models/model.pkl"): model.load_state_dict(torch.load("./models/model.pkl")) optimizer.load_state_dict(torch.load("./models/optimizer.pkl")) loss_list = [] def train(epoch): data_loader = get_dataloader(train=True) bar = tqdm(data_loader, total=len(data_loader)) for idx, (input, target, input_len, target_len) in enumerate(bar): input = input.to(config.device) target = target.to(config.device) input_len = input_len.to(config.device) optimizer.zero_grad() decoder_outputs = model(input, input_len,target) # [batch_Size,max_len,vocab_size] predict = decoder_outputs.view(-1, len(config.ns)) target = target.view(-1) loss = F.nll_loss(predict, target, ignore_index=config.ns.PAD) loss.backward() optimizer.step() loss_list.append(loss.item()) bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, idx, np.mean(loss_list))) if idx % 100 == 0: torch.save(model.state_dict(), "./models/model.pkl") torch.save(optimizer.state_dict(), "./models/optimizer.pkl") pickle.dump(loss_list, open("./models/loss_list.pkl", "wb")) if __name__ == ‘__main__‘: for i in range(5): train(i) eval() plt.figure(figsize=(50, 8)) plt.plot(range(len(loss_list)), loss_list) plt.show()
pytorch seq2seq模型中加入teacher_forcing机制
标签:ret set sts range mbed embed rect ignore 计算
原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12343829.html