标签:rnn 时序数据 deep learning
RNN(recurrent neural network)是神经网络的一种,主要用于时序数据的分析,预测,分类等。
RNN的general介绍请见下一篇文章《Deep learning From Image to Sequence》。本文针对对deep learning有一点基础(神经网络基本training原理,RBM结构及原理,简单时序模型)的小伙伴讲一下Bengio一个工作(RNNRBM)的原理和实现。
本文重点内容:针对RNN(recurrent neural network)一个应用:music composition进行架构和程序解读,参见paper:Modeling Temporal Dependencies in high-dimensional Sequences。
-----------------------------------------------------------------
Content:
1. RNN 的 general 架构及思想
2. RNN-RBM的定义
2.1 RTRBM结构
2.2 RNN-RBM网络架构
2.3 RNN-RBM的训练
3. RNN-RBM的实现及程序解读
-----------------------------------------------------------------
1. RNN 的 general 架构及思想
RNN是处理时序数据的NN模型,旨在建立时序数据模型,做模拟/预测/分类等。
fig1. Architecture of RNN
如上图A所示为RNN的基本结构,简单的说RNN就是由input units (u), internal units (x), output units (y)组成的neural network. 其中internal units 层内会有连接形成环。Intuitively,这样做的目的是希望让网络下一时刻的状态与当前时刻相关,即,一个有记忆的网络。
展开!如图B,是在t-1, t, t+1时刻network的参数传递(仅展示出forward-propagation中节点间相互decision情况)
2. RNN-RBM的定义:
2.1 RTRBM结构
首先我们以RTRBM(RNNRBM简化版,由Sutskever于08年提出)介入。以下是RTRBM的结构:
图中每一个红框框住了一个RBM,h是hidden states,v是visible nodes,比如表示为某一时刻的语音等(但实际上为了增加维度有些工作会把v(t)扩展为前后共n帧data的value),双向箭头表示h和v生成的条件概率,即:
(1)
其中σ是sigmoid函数。
对于每个时刻的RBM,v和h的联合概率分布为:
(2)
其中A(t)=,即所有t时刻之前的{v,h}集合。
此外对于RTRBM,可以理解为每个时刻可以由上一时刻的状态h(t-1)对该时刻产生影响(通过W‘和W‘‘),然后通过RBM得到一个(h(t),v(t))稳态。由于每一个参数都和上一时刻的参数有关,我们可以认为只有bias项是受hidden影响的,这样效果是一样的,即:
(2)
2.2 RNN-RBM网络架构
看到了RTRBM这个结构,bengio他们就想了,RTRBM结构里hidden layer描述的是visible的条件概率分布,只能保存暂时的信息(他应该指的是达到稳态后),那我能不能把这些rbm里的hidden layer用RNN代替?于是就冒出了RNN-RBM:
其中每个红框依然框住一个RBM,而下面绿框就表示了一个按时间展开了的RNN。这样设计的好处是把hiddenlayer分离了,一部分(h)只用于表示当前RBM的稳态state,另一部分(h^)表示RNN里的hidden节点。
PS: 关于RNN的网络结构:v(visible),u(internal units),h(hidden)
边:v-u, u-v, v-h(双向边,==h-v), u-h, u-u(实际上是环,只不过时序模型中unfold成u^t-u^{t+1})2.3 RNN-RBM 的训练
1. 由计算h^
2. 由(2)计算bh, bv,并根据k-step block gibbs采样得到v(t)
3. 通过NLL的cost对RBM里的参数(W, bh, bv)进行求导并更新
4. 估计RNN参数(W2, W3, bh^)并进行更新
3. RNN-RBM的实现及程序解读
3.1 准备工作及环境配置
3.1.1 参考程序见 : http://deeplearning.net/tutorial/rnnrbm.html
3.1.2 下载midi包(http://www.iro.umontreal.ca/~lisa/deep/midi.zip),extract到python包目录下(我的是/usr/lib/python2.7/dist-packages)
3.1.3 下载数据集(Nottingham Database of folk tunes),放在代码同文件夹下的data/
3.2 程序关键点解读:
1. build_rbm: 构建单个RBM, 进行k次vhv采样
输入:5个参数v(visible), W(RBM weight), bv(v_bias), bh(h_bias), k(param k in CD-k)RNN-RBM for music composition 网络架构及程序解读
标签:rnn 时序数据 deep learning
原文地址:http://blog.csdn.net/abcjennifer/article/details/27709915