【动手深度学习-笔记】注意力机制(三)多头注意力

在实践中,我们希望模型可以基于相同的注意力机制学习到不同的行为,抽取不同的信息(比如长距离依赖关系和短距离依赖关系),再将这些信息组合起来。
为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的h hh组不同的线性投影(linear projections)来变换查询、键和值,并行地送入到注意力汇聚,再将h hh组汇聚结果拼接到一起。
这种设计被称作多头注意力(multihead attention),融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
这里的“头”的概念类似于“通道”的概念,一个头表示一个注意力汇聚。

多个头连结然后线性变换

模型

给定查询q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q}qRdq,键k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k}kRdk,值v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v}vRdv,每个注意力头h i ( i = 1 , … , h ) \mathbf{h}_i(i = 1, \ldots, h)hi(i=1,,h)的计算方法:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , (1) \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},\tag{1}hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv,(1)
原本q \mathbf{q}qk \mathbf{k}kv \mathbf{v}v的维度分别为d q d_qdqd k d_kdkd v d_vdv,经过W i \mathbf{W}_iWi线性投影为p q p_qpqp k p_kpkp v p_vpv维度;
输入到注意力汇聚函数f ff得到汇聚结果h i ∈ R p v \mathbf{h}_i\in \mathbb{R}^{p_v}hiRpvf ff可以是加性注意力或者是缩放点积注意力等;
然后将得到的h i \mathbf{h}_ihi组合,经过W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times hp_v}WoRpo×hpv进行另一个线性转换,得到最终输出:
W o [ h 1 ⋮ h h ] ∈ R p o . \begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split}Woh1hhRpo.

其中的可学习参数为W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}Wi(q)Rpq×dqW i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}Wi(k)Rpk×dkW i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}Wi(v)Rpv×dvW o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times hp_v}WoRpo×hpv

每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

参考

10.5. 多头注意力 — 动手学深度学习 2.0.0-beta1 documentation


版权声明:本文为qq_41129489原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。