transformer中的attention为什么scaled?

原文链接:transformer中的attention为什么scaled?——LinT的回答

这个问题困扰良久,一直没研究清楚,只知道个大概,不知其所以然,这里专门开一篇总结一下。由于有人珠玉在前,写得极其精彩,所以直接转载了,以下为原文。

———————————————————————————————————————————————

谢邀。非常有意义的问题,我思考了好久,按照描述中的两个问题分点回答一下。

1. 为什么比较大的输入会使得softmax的梯度变得很小?

对于一个输入向量x ∈ R d \mathbf{x} \in \mathbb{R}^{d}xRd ,softmax函数将其映射/归一化到一个分布y ^ ∈ R d \hat{\mathbf{y}} \in \mathbb{R}^{d}y^Rd。在这个过程中,softmax先用一个自然底数e ee将输入中的元素间差距先“拉大”,然后归一化为一个分布。假设某个输入x xx中最大的的元素下标是k kk,如果输入的数量级变大(每个元素都很大),那么y ^ k \hat{y}_{k}y^k会非常接近1。

我们可以用一个小例子来看看x xx的数量级对输入最大元素对应的预测概率y ^ k \hat{y}_{k}y^k的影响。假定输入x = [ a , a , 2 a ] ⊤ \mathbf{x}=[a, a, 2 a]^{\top}x=[a,a,2a]),我们来看不同量级的a aa产生的y ^ 3 \hat{y}_{3}y^3有什么区别。

  • a = 1 a=1a=1时,y ^ 3 = 0.5761168847658291 \hat{y}_{3}=0.5761168847658291y^3=0.5761168847658291
  • a = 10 a=10a=10时,y ^ 3 = 0.999909208384341 \hat{y}_{3}=0.999909208384341y^3=0.999909208384341
  • a = 100 a=100a=100时,y ^ 3 ≈ 1.0 \hat{y}_{3} \approx 1.0y^31.0(计算精度限制);

我们不妨把a aa在不同取值下,对应的y ^ 3 \hat{y}_{3}y^3全部绘制出来。代码如下:

from math import exp
from matplotlib import pyplot as plt
import numpy as np 
f = lambda x: exp(x * 2) / (exp(x) + exp(x) + exp(x * 2))
x = np.linspace(0, 100, 100)
y_3 = [f(x_i) for x_i in x]
plt.plot(x, y_3)
plt.show()

得到的图如下所示:
在这里插入图片描述
可以看到,数量级对softmax得到的分布影响非常大。在数量级较大时,softmax将几乎全部的概率分布都分配给了最大值对应的标签。

然后我们来看softmax的梯度。不妨简记softmax函数为g ( ⋅ ) g(\cdot)g(),softmax得到的分布向量y ^ = g ( x ) \hat{\mathbf{y}}=g(\mathbf{x})y^=g(x)对输入x xx的梯度为:
∂ g ( x ) ∂ x = diag ⁡ ( y ^ ) − y ^ y ^ ⊤ ∈ R d × d \frac{\partial g(\mathbf{x})}{\partial \mathbf{x}}=\operatorname{diag}(\hat{\mathbf{y}})-\hat{\mathbf{y}} \hat{\mathbf{y}}^{\top} \quad \in \mathbb{R}^{d \times d}xg(x)=diag(y^)y^y^Rd×d 把这个矩阵展开:
∂ g ( x ) ∂ x = [ y ^ 1 0 ⋯ 0 0 y ^ 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ y ^ d ] − [ y ^ 1 2 y ^ 1 y ^ 2 ⋯ y ^ 1 y ^ d y ^ 2 y ^ 1 y ^ 2 2 ⋯ y ^ 2 y ^ d ⋮ ⋮ ⋱ ⋮ y ^ d y ^ 1 y ^ d y ^ 2 ⋯ y ^ d 2 ] \frac{\partial g(\mathbf{x})}{\partial \mathbf{x}}=\left[\begin{array}{cccc} \hat{y}_{1} & 0 & \cdots & 0 \\ 0 & \hat{y}_{2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \hat{y}_{d} \end{array}\right]-\left[\begin{array}{cccc} \hat{y}_{1}^{2} & \hat{y}_{1} \hat{y}_{2} & \cdots & \hat{y}_{1} \hat{y}_{d} \\ \hat{y}_{2} \hat{y}_{1} & \hat{y}_{2}^{2} & \cdots & \hat{y}_{2} \hat{y}_{d} \\ \vdots & \vdots & \ddots & \vdots \\ \hat{y}_{d} \hat{y}_{1} & \hat{y}_{d} \hat{y}_{2} & \cdots & \hat{y}_{d}^{2} \end{array}\right]xg(x)=y^1000y^2000y^dy^12y^2y^1y^dy^1y^1y^2y^22y^dy^2y^1y^dy^2y^dy^d2 根据前面的讨论,当输入x xx的元素均较大时,softmax会把大部分概率分布分配给最大的元素,假设我们的输入数量级很大,最大的元素是x 1 x_1x1,那么就将产生一个接近one-hot的向量y ^ ≈ [ 1 , 0 , ⋯ , 0 ] ⊤ \hat{\mathbf{y}} \approx[1,0, \cdots, 0]^{\top}y^[1,0,,0],此时上面的矩阵变为如下形式:
∂ g ( x ) ∂ x ≈ [ 1 0 ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ 0 ] − [ 1 0 ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ 0 ] = 0 \frac{\partial g(\mathbf{x})}{\partial \mathbf{x}} \approx\left[\begin{array}{cccc} 1 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{array}\right]-\left[\begin{array}{cccc} 1 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{array}\right]=\mathbf{0}xg(x)100000000100000000=0 也就是说,在输入的数量级很大时,梯度消失为0,造成参数更新困难。

注: softmax的梯度可以自行推导,网络上也有很多推导可以参考。

2. 维度与点积大小的关系是怎么样的,为什么使用维度的根号来放缩?

针对为什么维度会影响点积的大小,在论文的脚注中其实给出了一点解释:
在这里插入图片描述假设向量q qqk kk的各个分量是互相独立的随机变量,均值是0,方差是1,那么点积q ⋅ k q \cdot kqk的均值是0,方差是d k d_kdk。这里我给出一点更详细的推导:

∀ i = 1 , ⋯ , d k \forall i=1, \cdots, d_{k}i=1,,dkq i q_iqik i k_iki都是随机变量,为了方便书写,不妨记 X = q i X=q_iX=qiY = k i Y=k_iY=ki。这样有:D ( X ) = D ( Y ) = 1 D(X)=D(Y)=1D(X)=D(Y)=1E ( X ) = E ( Y ) = 0 E(X)=E(Y)=0E(X)=E(Y)=0。则:

  1. E ( X Y ) = E ( X ) E ( Y ) = 0 × 0 = 0 E(X Y)=E(X) E(Y)=0 \times 0=0E(XY)=E(X)E(Y)=0×0=0
  2. D ( X Y ) = E ( X 2 ⋅ Y 2 ) − [ E ( X Y ) ] 2 = E ( X 2 ) E ( Y 2 ) − [ E ( X ) E ( Y ) ] 2 = E ( X 2 − 0 2 ) E ( Y 2 − 0 2 ) − [ E ( X ) E ( Y ) ] 2 = E ( X 2 − [ E ( X ) ] 2 ) E ( Y 2 − [ E ( Y ) ] 2 ) − [ E ( X ) E ( Y ) ] 2 = D ( X ) D ( Y ) − [ E ( X ) E ( Y ) ] 2 = 1 × 1 − ( 0 × 0 ) 2 = 1 \begin{aligned} D(X Y) &=E\left(X^{2} \cdot Y^{2}\right)-[E(X Y)]^{2} \\ &=E\left(X^{2}\right) E\left(Y^{2}\right)-[E(X) E(Y)]^{2} \\ &=E\left(X^{2}-0^{2}\right) E\left(Y^{2}-0^{2}\right)-[E(X) E(Y)]^{2} \\ &=E\left(X^{2}-[E(X)]^{2}\right) E\left(Y^{2}-[E(Y)]^{2}\right)-[E(X) E(Y)]^{2} \\ &=D(X) D(Y)-[E(X) E(Y)]^{2} \\ &=1 \times 1-(0 \times 0)^{2} \\ &=1 \end{aligned}D(XY)=E(X2Y2)[E(XY)]2=E(X2)E(Y2)[E(X)E(Y)]2=E(X202)E(Y202)[E(X)E(Y)]2=E(X2[E(X)]2)E(Y2[E(Y)]2)[E(X)E(Y)]2=D(X)D(Y)[E(X)E(Y)]2=1×1(0×0)2=1

这样∀ i = 1 , ⋯ , d k \forall i=1, \cdots, d_{k}i=1,,dkq i ⋅ k i q_i \cdot k_iqiki的均值是0,方差是1,又由期望和方差的性质, 对相互独立的分量z i z_izi,有
E ( ∑ i Z i ) = ∑ i E ( Z i ) E\left(\sum_{i} Z_{i}\right)=\sum_{i} E\left(Z_{i}\right)E(iZi)=iE(Zi)
以及
D ( ∑ i Z i ) = ∑ i D ( Z i ) D\left(\sum_{i} Z_{i}\right)=\sum_{i} D\left(Z_{i}\right)D(iZi)=iD(Zi)
所以有q ⋅ k q \cdot kqk的均值E ( q ⋅ k ) = 0 E(q \cdot k)=0E(qk)=0,方差D ( q ⋅ k ) = d k D(q \cdot k)=d_{k}D(qk)=dk。方差越大也就说明,点积的数量级越大(以越大的概率取大值)。那么一个自然的做法就是把方差稳定到1,做法是将点积除以d k \sqrt{d}_{k}dk,这样有:
D ( q ⋅ k d k ) = d k ( d k ) 2 = 1 D\left(\frac{q \cdot k}{\sqrt{d}_{k}}\right)=\frac{d_{k}}{\left(\sqrt{d}_{k}\right)^{2}}=1D(dkqk)=(dk)2dk=1 将方差控制为1,也就有效地控制了前面提到的梯度消失的问题。

可以参考一下。水平有限,如果有误请指出。