
LSTM Cell in Detail and How to Implement It in PyTorch

Date: 2018-07-01 10:27:37


Refer to :






LSTM cells in PyTorch

This is an annotated illustration of the LSTM cell in PyTorch (admittedly inspired by the diagrams in Christopher Olah’s excellent blog article):







The yellow boxes correspond to a matrix multiplication followed by a non-linearity. W represents the weight matrices; the bias terms b have been omitted for simplicity. The mathematical symbols used in this diagram correspond to those used in PyTorch’s documentation of torch.nn.LSTM:

  • x(t): the external input (e.g. from training data) at time t
  • h(t-1)/h(t): the hidden state at times t-1 (‘input’) or t (‘output’). Despite its name, it also serves as the output of the cell and, in multi-layer networks, as the input to the next layer of LSTM cells
  • c(t-1)/c(t): the ‘cell state’ or ‘memory’ at times t-1 and t
  • f(t): the result of the forget gate. For values close to zero the cell will ‘forget’ its memory c(t-1) from the past; for values close to one it will keep it.
  • i(t): the result of the input gate, which determines how much of the (transformed) new input g(t) is written to the cell state.
  • g(t): the result of the cell gate, a non-linear (tanh) transformation of the new external input x(t) and the previous hidden state h(t-1)
  • o(t): the result of the output gate, which controls how much of the new cell state c(t), after a tanh, is passed on to the output (that is, the hidden state h(t))
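
For reference, the computations behind these symbols can be written out explicitly. The block below is a sketch following the formulation in the documentation of torch.nn.LSTM, with x_t standing for x(t) and so on. The bias terms b, which the diagram omits, are included here; σ is the sigmoid function and ⊙ is element-wise multiplication. The double-subscript names for W and b (for example W_{hf} for the hidden-to-forget weights) follow the documentation's convention and do not appear in the diagram.

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```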

It is also instructive to look at the implementation of torch.nn._functions.rnn.LSTMCell:

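import torch.nn.functional as F

def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        # The CUDA-specific branch is omitted in this excerpt; the generic
        # path below also works on CUDA tensors.
        pass

    # hidden is a tuple: (hidden state, cell state) at time t-1
    h_t_1, c_t_1 = hidden
    # Two linear maps produce the pre-activations of all four gates at once
    gates = F.linear(input, w_ih, b_ih) + F.linear(h_t_1, w_hh, b_hh)

    # Split along the feature dimension into four equally sized chunks
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate     = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate   = F.tanh(cellgate)
    outgate    = F.sigmoid(outgate)

    # New cell state: keep part of the old memory, add part of the new input
    c_t = (forgetgate * c_t_1) + (ingate * cellgate)
    # New hidden state: gated, squashed view of the cell state
    h_t = outgate * F.tanh(c_t)

    return h_t, c_t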
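
As a sanity check, the arithmetic above can be compared against the built-in torch.nn.LSTMCell module. The following is a minimal sketch, assuming a PyTorch version in which nn.LSTMCell exposes its parameters as weight_ih, weight_hh, bias_ih and bias_hh. The function name manual_lstm_cell and the tensor sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def manual_lstm_cell(x, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Generic (non-fused) LSTM cell step; name chosen for this example."""
    h_prev, c_prev = hidden
    gates = F.linear(x, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    # Chunks come out in the order: input, forget, cell, output
    i, f, g, o = gates.chunk(4, 1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c_t = f * c_prev + i * g
    h_t = o * torch.tanh(c_t)
    return h_t, c_t


torch.manual_seed(0)
batch, input_size, hidden_size = 3, 5, 4  # illustrative sizes

cell = nn.LSTMCell(input_size, hidden_size)
x = torch.randn(batch, input_size)
h0 = torch.randn(batch, hidden_size)
c0 = torch.randn(batch, hidden_size)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h0, c0))
    h_man, c_man = manual_lstm_cell(
        x, (h0, c0),
        cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh,
    )

# Largest element-wise deviation between the two implementations
max_h_diff = (h_ref - h_man).abs().max().item()
max_c_diff = (c_ref - c_man).abs().max().item()
print(max_h_diff, max_c_diff)
```

Both differences should be at the level of floating-point rounding, which indicates that the module applies the same gate order and the same update rule as the function shown above.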


The second argument (hidden) is expected to be a tuple (h(t-1), c(t-1)), that is, the hidden state and the cell/memory state at time t-1. The return value has the same format, but for time t.
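
Because the input and output states share one format, the returned tuple can be fed straight back in at the next time step. The sketch below illustrates this with torch.nn.LSTMCell, assuming a zero initial state and random input data; the sizes and variable names are hypothetical and serve only as an example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 6, 2, 3, 4  # illustrative sizes

cell = nn.LSTMCell(input_size, hidden_size)
inputs = torch.randn(seq_len, batch, input_size)  # one slice per time step

# Initial state (h(0), c(0)); zeros are an assumption made for this example
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)

outputs = []
with torch.no_grad():
    for x_t in inputs:
        # The state returned for time t becomes the input state for time t+1
        h, c = cell(x_t, (h, c))
        outputs.append(h)

# Hidden states of all time steps: (seq_len, batch, hidden_size)
outputs = torch.stack(outputs)
print(outputs.shape)
```

The last entry of outputs is the same tensor as the final h, since the hidden state doubles as the output of the cell at each step.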

