油封
2023/01/11阅读:31主题:默认主题
循环神经网络RNN&LSTM推导及实现
1. 从神经网络谈起
了解神经网络的都知道,神经网络作为一种非线性模型,在监督学习领域取得了state-of-art的效果,其中反向传播算法的提出居功至伟,到如今仍然是主流的优化神经网络参数的算法. 递归神经网络、卷积神经网络以及深度神经网络作为人工神经网络的"变种",仍然延续了ANN的诸多特质,如权值连接,激励函数,以神经元为计算单元等,只不过因为应用场景的不同衍生了不同的特性,如:处理变长数据、权值共享等。
为了介绍RNN,先简单的介绍ANN. ANN的结构很容易理解,一般是三层结构(输入层-隐含层-输出层). 隐含层输出 和输出层输出 如下。其中 为隐含层第 个神经元的输入, 为输入层和隐含层的连接权值矩阵, 为隐含层和输出层之间的连接权值矩阵.
定义损失函数为
从对
反向传播的实质是基于梯度下降的优化方法,只不过在优化的过程使用了一种更为优雅的权值更新方式。
2. 循环神经网络
传统的神经网络一般都是全连接结构,且非相邻两层之间是没有连接的。一般而言,定长输入的样本很容易通过神经网络来解决,但是类似于NLP中的序列标注这样的非定长输入,前向神经网络却无能为力。
循环神经网络(Recurrent Neural Network, RNN)可以解决这类问题,RNN一般认为是网络隐层节点之间有相互作用的连接,其实质可以认为是多个具有相同结构和参数的前向神经网络的stacking, 前向神经网络的数目和输入序列的长度一致,且序列中毗邻的元素对应的前向神经网络的隐层之间有互联结构,其图示( 图片来源 )如下.
上图只是一个比较抽象的结构,下面是一个以时间展开的更为具体的结构(图片来源).
从图中可以看出,输出层神经元的输入输出和前向神经网络中没有什么差异,仅仅在于隐层除了要接收输入层的输入外,还需要接受来自于自身的输入(可以理解为t时刻的隐层需要接收来自于t-1时刻隐层的输入, 当然这仅限于单向RNN的情况,在双向RNN还需要接受来自t+1时刻的输入).
RNN的隐层是控制信息传递的重要单元,不同时刻隐层之间的连接权值决定了过去时刻对当前时刻的影响,所以会存在时间跨度过大而导致这种影响会削弱甚至消失的现象,称之为梯度消失,改进一般都是针对隐层做文章,LSTM(控制输入量,补充新的信息然后输出),GRU(更新信息然后输出)等都是这类的改进算法.
下图为某时刻隐层单元的结构示意图(图片来源).
虽说处理的是不定长输入数据,但是某个时刻的输入还是定长的。令t时刻:输入
3. RNN推导
令隐含层的激励函数
对于单个样本,定义我们的cost function

令
所以不难得到误差方向传递时的权值计算方式:
写成矩阵的形式即是(其中
上述 公式其实在了解反向传播之后就能够很容易推导,对于
4.RNN实现
推导出了公式,用代码实现就很简单了,可以看到前向传递和反向传播过程中的都用到了矩阵的相关运算,封装好一个矩阵计算的类,然后重载相关的运算符,便可以十分方便的进行相关的计算了。矩阵类的实现见https://github.com/kymo/SUDL/blob/master/matrix/matrix.h ,然后直接按照上述的矩阵运算方式填写代码就可以了,详见:https://github.com/kymo/SUDL/blob/master/rnn/rnn.cpp.
5. LSTM推导
相比于普通的RNN而言,LSTM无非是多了几个门结构,求导的过程略微繁琐一点,但是只要清楚当前待求导变量到输出的路径,变可以依据链式求导法则得到对应的求导公式。LSTM是为了缓解RNN的梯度弥散问题而提出的一种变种模型,之所以能够缓解这个问题,是因为在原始的RNN中,梯度的传递是乘法的过程,如果梯度很小,那么从T时刻传递到后面的梯度只会越来越小,甚至消失,在优化空间中相当于一部分参数进行更新,而另外一部分参数几乎不变,那么问题的较优解也就很难收敛到。而LSTM通过推导会发现,梯度是以一种累加的方式进行反向传递的,从而一定程度上客服了累乘导致的梯度弥散的问题。下面要推导的LSTM 是不加peelhole的结构。要推导LSTM首先要了解LSTM的几个门结构:输入门、输出门以及遗忘门。从传统的神经网络的角度来看的话,三种门分别对应了三个并行的神经网络层,如下图所示(其中蓝色和黄色表示神经元层结构,有输入、输出和相应的激励函数;红色表示向量对应项相乘;淡绿色表示向量加)
三种门的计算方式如下:
另外,LSTM中也引入一个Cell State的中间状态,该状态有选择的保存了过去的历史信息,不会像RNN一样存在越久远的记忆越模糊的问题。该Cell State的更新方式主要是依赖于过去时刻的Cell State(
遗忘门用来控制过去时刻的Cell State有多少需要被保留,输入门用来控制增加多少新的信息到Cell State中,公式表述如下:
另外我们也需要定义两个新的中间变量
so, 另外一波公式来袭~~