LSTM 入门级解读

记录学习过程,方便日后查用。本贴包括数学计算过程和模型解读。

如有错误请指出,感谢大家的指导。

图片来源 LSTM模型结构的可视化 - 知乎

淡绿色的方块被称为cell,是构成LSTM的主要结构。实际上对于RNN类网络来说,都会有一个这样的结构块,在时间上循环这个结构块就构成了RNN网络。上图是最基础的LSTM网络。

LSTM的单元输入总共有3个部分 h是隐藏层,X是数据输入,C可以看成是网络的记忆部分。所有红色的单元是运算符,运算过程就是简单的套用运算符;所有黄色的单元是网络层,运算过程类似感知机,sigma符号代表的激活函数默认为sigmoid函数。

数学运算过程

LSTM的cell内部总共有3个主要的门,第一个被称为忘记门(forget gate),用来决定上一轮的输入能有多少影响到这一轮的输入。

忘记门公式

F_a = \sigma(W_f*[x_{t},h_{t-1}]+b_f)

F_t = C_{t-1}*F_a

中括号表示concatenate,单纯的将两个向量进行维度上的合并,如x有100维,h有200维,那么中括号就会返回一个300维的向量。忘记门会对上一轮的输入做一个筛选,和输入门的输入一起做加法得到本轮的记忆。

输入门决定了这一轮的主要输入。

Ia= \sigma(W_{i}*[x_{t},h_{t-1}]+b_{i})

\widetilde{C_t} = tanh(W_{c})*[x_{t},h_{t-1}]+b_c

C_t = I_{a}*\widetilde{C_t}+F_t

输出门

O_a = \sigma(W_o*[x_{t},h_{t-1}]+b_o)

h_t = O_a*tanh(C_t)

至此,我们得到了本轮输出C_t,h_t

模型简单解读

LSTM能拥有长时记忆的主要原因就在于变量C,C的运算结构中包含了加法。对比传统的RNN网络只有一个tanh来说,更不容易出现梯度爆炸或者梯度消失的情况。

LSTM的参数个数计算。假设词向量的维度是m,隐藏层维度为n。

那么参数总数为((m+n)*n+n)*4。

上文提到LSTM虽然是链式结构,但是是在时间上循环同一个单元,所以cell之间所有的参数是共享的,不共享的是每个cell内4个网络层的参数。4个网络层都是感知机的模式,相当于4个全连接层。全连接层的输入维度是xh的concatenate,输出维度是h,再加上偏置和输出维度相同,所以参数数量一共是 (m+n)*n+n,见下图。因为总共有4个这样的网络,所以再乘以4。

题外话

一般在使用LSTM做文本数据时,我们关注的不仅仅只是过去的信息,可能还会有未来的信息,即“结合上下文”。所以会用上双向的LSTM,双向LSTM可以看成LSTM和他的镜像结合在一起,最终两个LSTM的隐层结合一下再输出。所以参数个数是LSTM的两倍

LSTM还有很多的变种,但大体的结构都大差不大。


版权声明:本文为wuhao1205原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。