深度学习之循环神经网络(10)GRU简介

深度学习之循环神经网络(10)GRU简介


LSTM具有更长的记忆能力,在大部分序列任务上面都取得了比基础RNN模型更好的性能表现,更重要的是,LSTM不容易出现梯度弥散现象。但是LSTM结构相对较复杂,计算代价较高,模型参数量较大。因此科学家们尝试简化LSTM内部的计算流程,特别是减少门控数量。研究发现,遗忘门是LSTM中最重要的门控[1],甚至发现只有遗忘门的简化版网络在多个基准数据集上面优于标准LSTM网络。在众多的简化版LSTM中, 门控循环网络(Gated Recurrent Unit,简称GRU)是应用最广泛的RNN变种之一。GRU把内部状态向量和输出向量合并,统一为状态向量 h \boldsymbol hh,门控数量也较少到2个: 复位门(Reset Gate)更新门(Update Gate),如下图所示:

GRU网络结构


下面我们来分别介绍复位门和更新门的原理与功能。


[1]J. Westhuizen 和 J. Lasenby, “The unreasonable effectiveness of the forget gate,” CoRR, 卷 abs/1804.04849, 2018.


1. 复位门

复位门用于控制上一个时间戳的状态h t − 1 \boldsymbol h_{t-1}ht1进入GRU的量。门控向量g r \boldsymbol g_rgr由当前时间戳输入x t \boldsymbol x_txt和上一时间戳状态h t − 1 \boldsymbol h_{t-1}ht1变换得到,关系如下:
g r = σ ( W r [ h t − 1 , x t ] + b r ) \boldsymbol g_r=σ(\boldsymbol W_r [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_r)gr=σ(Wr[ht1,xt]+br)
其中W r \boldsymbol W_rWrb r \boldsymbol b_rbr为复位门的参数,由反向传播算法自动优化,σ σσ为激活函数,一般使用Sigmoid函数。门控向量g r = 0 \boldsymbol g_r=0gr=0时,新输入h ~ t \tilde \boldsymbol h_th~t全部来自于输入x t \boldsymbol x_txt,不接受h t − 1 \boldsymbol h_{t-1}ht1,此时相当于复位h t − 1 \boldsymbol h_{t-1}ht1。当g r = 1 \boldsymbol g_r=1gr=1时,h t − 1 h_{t-1}ht1和输入x t \boldsymbol x_txt共同产生新输入h ~ t \tilde\boldsymbol h_th~t,如下图所示:

复位门

2. 更新门

更新门用控制上一时间戳状态h t − 1 \boldsymbol h_{t-1}ht1和新输入h ~ t \tilde\boldsymbol h_th~t对新状态向量h t \boldsymbol h_tht的影响程度。更新门控向量g z \boldsymbol g_zgz
g z = σ ( W z [ h t − 1 , x t ] + b z ) \boldsymbol g_z=σ(\boldsymbol W_z [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_z)gz=σ(Wz[ht1,xt]+bz)
得到,其中W z \boldsymbol W_zWzb z \boldsymbol b_zbz为更新门的参数,由反向传播算法自动优化,σ σσ为激活函数,一般使用Sigmoid函数。g z \boldsymbol g_zgz用于控制新输入h ~ t \tilde\boldsymbol h_th~t信号,1 − g z 1-\boldsymbol g_z1gz用于控制状态h t − 1 \boldsymbol h_{t-1}ht1信号:
h t = ( 1 − g z ) h t − 1 + g z h ~ t \boldsymbol h_t=(1-\boldsymbol g_z ) \boldsymbol h_{t-1}+\boldsymbol g_z \tilde\boldsymbol h_tht=(1gz)ht1+gzh~t

更新门


可以看到,h ~ t \tilde\boldsymbol h_th~th t − 1 \boldsymbol h_{t-1}ht1的更新量处于相互竞争、此消彼长的状态。当更新门g z = 0 \boldsymbol g_z=0gz=0时,h t \boldsymbol h_tht全部来自上一时间戳状态h t − 1 \boldsymbol h_{t-1}ht1;当更新门g z = 1 \boldsymbol g_z=1gz=1时,h t \boldsymbol h_tht全部来自新输入h ~ t \tilde\boldsymbol h_th~t


3. GRU使用方法

同样地,在TensorFlow中,也有Cell方式和层方式实现GRU网络。GRUCell和GRU层的使用方法和之前的SimpleRNNCell、LSTMCell、SimpleRNN和LSTM非常类似。首先是GRUCell的使用,创建GRU Cell对象,并在时间轴上循环展开运算。例如:

import tensorflow as tf
from tensorflow.keras import layers

x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :]  # 得到一个时间戳的输入
# 初始化状态向量,GRU只有一个
h = [tf.zeros([2, 64])]
cell = layers.GRUCell(64)  # 新建GRU Cell,向量长度为64
# 在时间戳维度上解开,循环通过cell
for xt in tf.unstack(x, axis=1):
    out, h = cell(xt, h)
# 输出形状
print(out.shape)


运行结果如下所示:

(2, 64)


通过layers.GRU类可以方便创建一层GRU网络层,通过Sequential容器可以堆叠多层GRU层的网络。例如:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential

x = tf.random.normal([2, 80, 100])
xt = x[:, 0, :]  # 得到一个时间戳的输入
# 初始化状态向量,GRU只有一个
h = [tf.zeros([2, 64])]
net = keras.Sequential([
    layers.GRU(64, return_sequences=True),
    layers.GRU(64)
])
out = net(x)
# 输出形状
print(out.shape)


运行结果如下所示:

(2, 64)

版权声明:本文为weixin_43360025原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。