
PyTorch seq2seq model example

Posted: 2019-11-07


The following code will help you get more familiar with how a seq2seq model works.

"""
    test
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable  # a no-op wrapper since PyTorch 0.4; kept to match the original post

# Build the vocabulary
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']  # 'S': decoder start, 'E': end, 'P': padding
num_dict = {n: i for i, n in enumerate(char_arr)}
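# With this ordering, num_dict['S'] == 0, num_dict['E'] == 1, num_dict['P'] == 2,
# num_dict['a'] == 3, ..., so n_class below works out to 29 (3 special tokens + 26 letters).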

# Network parameters
n_step = 5  # length of the longest word; shorter words are padded with 'P'
n_hidden = 128
n_class = len(num_dict)
batch_size = len(seq_data)

# Prepare the data
def make_batch(seq_data):
    input_batch, output_batch, target_batch = [], [], []

    for seq in seq_data:
        for i in range(2):
            seq[i] = seq[i] + 'P' * (n_step - len(seq[i]))  # pad both words to n_step
        # e.g. ['man', 'women'] -> encoder input 'manPP', decoder input 'Swomen', target 'womenE'
        input = [num_dict[n] for n in seq[0]]
        output = [num_dict[n] for n in ('S' + seq[1])]  # decoder input starts with 'S'
        target = [num_dict[n] for n in (seq[1] + 'E')]  # decoder target ends with 'E'

        input_batch.append(np.eye(n_class)[input])   # one-hot rows taken from the identity matrix
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)                  # class indices; CrossEntropyLoss needs no one-hot

    return Variable(torch.Tensor(np.array(input_batch))), Variable(torch.Tensor(np.array(output_batch))), Variable(torch.LongTensor(target_batch))

input_batch, output_batch, target_batch = make_batch(seq_data)
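# Optional shape check (added; not in the original post):
#   input_batch:  (6, 5, 29)  -> (batch_size, n_step, n_class), one-hot
#   output_batch: (6, 6, 29)  -> the leading 'S' adds one time step
#   target_batch: (6, 6)      -> plain class indices
# print(input_batch.shape, output_batch.shape, target_batch.shape)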


# Build the network
class Seq2Seq(nn.Module):
    """
    Key points:
    1. The network consists of an encoder and a decoder that share the same RNN
       structure, followed by a fully connected layer that produces the predictions.
    2. Know the RNN structure well.
    3. The essence of seq2seq: the hidden state produced by the encoder is used
       as the initial hidden state of the decoder.
    """
    def __init__(self):
        super().__init__()
        # input_size is the number of states each step can take on (the vocabulary size);
        # hidden_size is the dimension of the hidden state.
        # Note: with num_layers=1 the dropout argument has no effect, and PyTorch warns about it.
        self.enc = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.dec = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, enc_input, enc_hidden, dec_input):
        # nn.RNN expects input of shape (seq_len, batch_size, n_class), so transpose first
        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)
        _, enc_states = self.enc(enc_input, enc_hidden)  # enc_states: (1, batch_size, n_hidden)
        outputs, _ = self.dec(dec_input, enc_states)     # decoder starts from the encoder's final state
        pred = self.fc(outputs)

        return pred
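
# Note (added): during training the decoder is fed the ground-truth sequence
# 'S' + target at every step (teacher forcing); at test time it only sees
# padding, as in translate() below.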


# training
model = Seq2Seq()
loss_fun = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5000):
    hidden = Variable(torch.zeros(1, batch_size, n_hidden))  # initial hidden state: (num_layers, batch_size, n_hidden)

    optimizer.zero_grad()
    pred = model(input_batch, hidden, output_batch)
    pred = pred.transpose(0, 1)  # -> (batch_size, seq_len, n_class)
    loss = 0
    for i in range(len(seq_data)):
        loss += loss_fun(pred[i], target_batch[i])  # (seq_len, n_class) vs (seq_len,)
    if (epoch + 1) % 1000 == 0:
        print('Epoch: %04d  Cost: %.6f' % (epoch + 1, loss))
    loss.backward()
    optimizer.step()
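
# Note (added): an equivalent loop-free loss would be
#   loss = loss_fun(pred.reshape(-1, n_class), target_batch.reshape(-1))
# It differs from the summed per-sample loss above only by a constant factor of
# batch_size, since CrossEntropyLoss averages over all elements by default.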


# Testing
def translate(word):
    # at inference time the decoder input is all padding
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])
    # hidden shape: (1, 1, n_hidden)
    hidden = Variable(torch.zeros(1, 1, n_hidden))
    # output shape: (6, 1, n_class)
    output = model(input_batch, hidden, output_batch)
    predict = output.data.max(2, keepdim=True)[1]  # best class index at each step
    decoded = [char_arr[int(i)] for i in predict.squeeze()]
    end = decoded.index('E')  # cut everything from the end token onwards
    translated = ''.join(decoded[:end])

    return translated.replace('P', '')

print('girl ->', translate('girl'))
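
To spot-check the remaining training pairs, the same helper can be looped over the source words (a minimal usage sketch, assuming the training loop above has finished):

for w in ['man', 'black', 'king', 'up', 'high']:
    print(w, '->', translate(w))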

Reference: https://blog.csdn.net/weixin_43632501/article/details/98525673


Original post: https://www.cnblogs.com/demo-deng/p/11811090.html
