码迷,mamicode.com
首页 > 其他好文 > 详细

deep_learning_Function_rnn_cell.BasicLSTMCell

时间:2019-10-08 14:03:03      阅读:78      评论:0      收藏:0      [点我收藏+]

标签:NPU   pytho   init   get   返回   ali   dash   rand   state   

tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True): n_hidden表示神经元的个数,forget_bias就是LSTM们的忘记系数,如果等于1,就是不会忘记任何信息。如果等于0,就都忘记。state_is_tuple默认就是True,官方建议用True,就是表示返回的状态用一个元祖表示。这个里面存在一个状态初始化函数,就是zero_state(batch_size,dtype)两个参数。batch_size就是输入样本批次的数目,dtype就是数据类型。

例如:

import tensorflow as tf
 
batch_size = 4
input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)

output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state, time_major=True) #time_major如果是True,就表示RNN的steps用第一个维度表示,建议用这个,运行速度快一点。
#如果是False,那么输入的第二个维度就是steps。
#如果是True,output的维度是[steps, batch_size, depth],反之就是[batch_size, max_time, depth]。就是和输入是一样的

#final_state就是整个LSTM输出的最终的状态,包含c和h。c和h的维度都是[batch_size, n_hidden]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output))
    print(sess.run(final_state))

————————————————
原文链接:https://blog.csdn.net/UESTC_C2_403/article/details/73353145

deep_learning_Function_rnn_cell.BasicLSTMCell

标签:NPU   pytho   init   get   返回   ali   dash   rand   state   

原文地址:https://www.cnblogs.com/0405mxh/p/11634844.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!