标签:ext 完成 状态 手写 根据 class ret next util
1 def rnn_cell_forward(xt, a_prev, parameters): 2 """ 3 根据图2实现RNN单元的单步前向传播 4 5 参数: 6 xt -- 时间步“t”输入的数据,维度为(n_x, m) 7 a_prev -- 时间步“t - 1”的隐藏隐藏状态,维度为(n_a, m) 8 parameters -- 字典,包含了以下内容: 9 Wax -- 矩阵,输入乘以权重,维度为(n_a, n_x) 10 Waa -- 矩阵,隐藏状态乘以权重,维度为(n_a, n_a) 11 Wya -- 矩阵,隐藏状态与输出相关的权重矩阵,维度为(n_y, n_a) 12 ba -- 偏置,维度为(n_a, 1) 13 by -- 偏置,隐藏状态与输出相关的偏置,维度为(n_y, 1) 14 15 返回: 16 a_next -- 下一个隐藏状态,维度为(n_a, m) 17 yt_pred -- 在时间步“t”的预测,维度为(n_y, m) 18 cache -- 反向传播需要的元组,包含了(a_next, a_prev, xt, parameters) 19 """ 20 21 # 从“parameters”获取参数 22 Wax = parameters["Wax"] 23 Waa = parameters["Waa"] 24 Wya = parameters["Wya"] 25 ba = parameters["ba"] 26 by = parameters["by"] 27 28 # 使用上面的公式计算下一个激活值 29 a_next = np.tanh(np.dot(Waa, a_prev) + np.dot(Wax, xt) + ba) 30 31 # 使用上面的公式计算当前单元的输出 32 yt_pred = rnn_utils.softmax(np.dot(Wya, a_next) + by) 33 34 # 保存反向传播需要的值 35 cache = (a_next, a_prev, xt, parameters) 36 37 return a_next, yt_pred, cache 38
|
实现一个RNN单元,这需要由以下几步完成:
|
标签:ext 完成 状态 手写 根据 class ret next util
原文地址:https://www.cnblogs.com/nxf-rabbit75/p/9943406.html