清华提出 DynamicViT,几乎不降低模型性能的情况下,较大地减少计算量和参数量

感觉其他人总结得不好,不详细,就是翻译了一下,因此自己仔细看了一遍并写了以下的总结。最后一部分的老师模型(teacher’s model)没有详细看,因为目前还没太看过蒸馏这一块,以后看了可以补齐。

文章目的:在几乎没有降低模型性能的情况下,较大地(30%以上)减少了参数量和计算量。

论文地址: https://arxiv.org/abs/2106.02034

摘要

本文是清华和 UCLA 一起提出的,关于对 Transformer 模型剪枝的优化方法,主要贡献如下:

  • 训练一个轻量级的网络,来判断输入的 token 中,哪些重要,哪些不重要;重要的保留,不重要的丢弃;
  • 提出一个注意力掩码策略,将需要丢弃的 token 与其它 token 之间的联系切断,并使得模型可导;
  • 接着用蒸馏方法,用 teacher 模型来训练一个学生模型,以达到进一步精简的目的;

在这里插入图片描述

简介

首先,提出一下轻量级的模型,来动态地决定谁可以哪个 token 可以被剪枝掉。具体来说就是,对每个输入,模型都会生成一个二元决定掩码,来判断哪个 token 没有太多信息,因此可以被丢弃。

这个模块可以加在多个 layer 上,这样就可以层次化地做稀疏化。

每当一个 token 被剪掉的时候,它不会再被传到之后的 feed-forward 层。

为了实现端到端的训练,使用了两个特殊的策略。

  1. 采用了 Gumbel-Softmax,来克服当从分布中采样后不可导的情况;
  2. 第二个是关于如何使用这个二元掩码来剪枝

如果直接将没太多信息的 token 丢弃,会带来并行训练时的一些问题(如果直接将几个 token 丢弃,每个 batch 里面的 token 数量就会不一致,因为并不是每张图片都会丢弃相同数量的 token)

因此,我们根据得到的二元掩码,将被丢弃的 token 与其它所有 token 在注意力矩阵中的连接丢弃。

此外,作者还改动了原始的训练目标函数。

在推断时,只需要将固定数量的 token 直接丢弃就行,不存在不可导的问题。
在这里插入图片描述

Dynamic Vision Transformers

简介

模型由一个普通的 Transformer backbone,和几个预测模块组成。

预测模块负责生成 token 需要被丢弃或者保存的概率。

预测头的分层 token 稀疏化

DynamicViT 一个很重要的特性是分层对 token 进行稀疏化,也就是会随着计算进行,将没有太多信息量的 token 丢弃。

为了达到这个目的,作者维持着一个二元决策掩码 D ^ ∈ { 0 , 1 } N \hat{\mathbf{D}} \in\{0,1\}^{N}D^{0,1}N,来决定是否丢弃每个 token。 此处 N = H W N = HWN=HW 表示 patch 的数量。二元决策掩码全部被初始化为1,且 class token 的掩码一直保持为1.

决策模块的输入是当前的决策掩码 D ^ \hat{\mathbf{D}}D^ 和所有输入的 token x ∈ R N × C \mathbf{x} \in \mathbb{R}^{N \times C}xRN×C,首先将 tokens 用 MLP 进行映射

z local  = M L P ( x ) ∈ R N × C ′ \mathbf{z}^{\text {local }}=\mathrm{MLP}(\mathbf{x}) \in \mathbb{R}^{N \times C^{\prime}}zlocal =MLP(x)RN×C

(上面的 C CC 应该是 embedding_size,比如768),C ′ C'C 可以是一个较小的尺寸,作者使用的是 C ′ = C / 2 C^{\prime}=C / 2C=C/2

接着,可以计算一个全局的特征:

z global  = Agg ⁡ ( MLP ⁡ ( x ) , D ) ∈ R C ′ \mathbf{z}^{\text {global }}=\operatorname{Agg}(\operatorname{MLP}(\mathbf{x}), D) \in \mathbb{R}^{C^{\prime}}zglobal =Agg(MLP(x),D)RC

上面的 Agg 是将所有 token 中的信息聚合起来,由下面的全局池化函数实现:

Agg ⁡ ( u , D ^ ) = ∑ i = 1 N D ^ i u i ∑ i = 1 N D ^ i , u ∈ R N × C ′ \operatorname{Agg}(\mathbf{u}, \hat{\mathbf{D}})=\frac{\sum_{i=1}^{N} \hat{\mathbf{D}}_{i} \mathbf{u}_{i}}{\sum_{i=1}^{N} \hat{\mathbf{D}}_{i}}, \quad \mathbf{u} \in \mathbb{R}^{N \times C^{\prime}}Agg(u,D^)=i=1ND^ii=1ND^iui,uRN×C

局部特征中编码着某个 token 中的信息,而全局特征中编码着整张图片的上下文信息,它们都非常重要。因此将它们拼接在一起,得到局部-全局的嵌入,然后送入另一个 MLP,来预测 token 是否被丢弃/保留。

z i = [ z i local  , z i global  ] , 1 ≤ i ≤ N π = Softmax ⁡ ( MLP ⁡ ( z ) ) ∈ R N × 2 \begin{aligned}&\mathbf{z}_{i}=\left[\mathbf{z}_{i}^{\text {local }}, \mathbf{z}_{i}^{\text {global }}\right], \quad 1 \leq i \leq N \\&\boldsymbol{\pi}=\operatorname{Softmax}(\operatorname{MLP}(\mathbf{z})) \in \mathbb{R}^{N \times 2}\end{aligned}zi=[zilocal ,ziglobal ],1iNπ=Softmax(MLP(z))RN×2

(上式中是将 z zz 先经过一个 MLP 并预测出一个 binary value,然后送入 Softmax 并得到最终的结果,因此形状为 N × 2 N \times 2N×2

π i , 0 \boldsymbol{\pi}_{i, 0}πi,0 代表第 i ii 个 token 被丢弃的概率, π i , 1 \boldsymbol{\pi}_{i, 1}πi,1代表被保留的概率。

然后,从 π \piπ 中采样并生成当前的决策掩码 D DD,并更新 D ^ \hat DD^

D ^ ← D ^ ⊙ D \hat{\mathbf{D}} \leftarrow \hat{\mathbf{D}} \odot \mathbf{D}D^D^D

⊙ \odot 是哈德玛内积(Hadamard Product),以上表示一旦一个 token 被丢弃,就不会再被使用。

以下是预测模块的代码,参考自 https://zhuanlan.zhihu.com/p/380353779:

class PredictorLG(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, embed_dim=384):
        super().__init__()
        # local建模
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )
        # 为每个token预测一个两维的向量,来表示当前token是否需要被mask掉
        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x, policy):
        # x 表示当前输入的tokens
        # policy表示当前的mask,由0和1组成,0表示不需要参与后序计算的token
        x = self.in_conv(x)  # 对于输入的每一个token先经过一层linear projection对局部信息进行建模
        B, N, C = x.size()
        local_x = x[:,:, :C//2]
        # 在计算全局向量的时候,只对参与后序计算的token进行全局池化操作
        global_x = (x[:,:, C//2:] * policy).sum(dim=1, keepdim=True) / torch.sum(policy, dim=1, keepdim=True)
        # 将全局向量与局部向量拼接
        x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1)
        # 通过简单的MLP来输出每个token是否需要保留的一个分数,一组score
        return self.out_conv(x)

注意力掩码的端到端优化

首先,为了解决在采样后不可导的情况,使用了 Gumbel-Softmax 从 π \piπ 中进行采样

D = Gumbel ⁡ − Softmax ⁡ ( π ) ∗ , 1 ∈ { 0 , 1 } N \mathbf{D}=\operatorname{Gumbel}-\operatorname{Softmax}(\boldsymbol{\pi})_{*, 1} \in\{0,1\}^{N}D=GumbelSoftmax(π),1{0,1}N

Gumbel-Softmax 的输出是一个 one-hot 向量,且它的期望等于 π \piπ

至于第二点,由于前述原因,不能直接将判为需要丢弃的 token 直接丢弃,因此作者选择将需要被剪枝的 token 与其它 token 的连接剪断。因此提出一个注意力掩码方法。

P = Q K T / C ∈ R N × N \mathbf{P}=\mathbf{Q K}^{T} / \sqrt{C} \in \mathbb{R}^{N \times N}P=QKT/CRN×N

G i j = { 1 , i = j , D ^ j , i ≠ j 1 ≤ i , j ≤ N \mathbf{G}_{i j}=\left\{\begin{array}{ll}1, & i=j, \\\hat{\mathbf{D}}_{j}, & i \neq j\end{array} \quad 1 \leq i, j \leq N\right.Gij={1,D^j,i=j,i=j1i,jN

A ~ i j = exp ⁡ ( P i j ) G i j ∑ k = 1 N exp ⁡ ( P i k ) G i k , 1 ≤ i , j ≤ N \tilde{\mathbf{A}}_{i j}=\frac{\exp \left(\mathbf{P}_{i j}\right) \mathbf{G}_{i j}}{\sum_{k=1}^{N} \exp \left(\mathbf{P}_{i k}\right) \mathbf{G}_{i k}}, \quad 1 \leq i, j \leq NA~ij=k=1Nexp(Pik)Gikexp(Pij)Gij,1i,jN

用上面中间的那个公式来创建一个 graph,其中 G i j = 1 \mathbf{G}_{i j}=1Gij=1 意思是第 j jj个 token 会对第 i ii 个 token 的更新起作用。注意到此处加了一个自环(self-loop)来提高稳定性,能看出这个自环不会影响结果,即如果 D ^ j = 0 \hat{\mathbf{D}}_{j}=0D^j=0 ,那第 j jj 个 token 不会对除它自己外的任何 token 有影响。(这里也没太理解)

训练及推断

L c l s =  CrossEntropy  ( y , y ‾ ) \mathcal{L}_{\mathrm{cls}}=\text { CrossEntropy }(\mathbf{y}, \overline{\mathbf{y}})Lcls= CrossEntropy (y,y)

为了最小化稀疏 token 带来的影响,作者使用了原始的 backbone 模型来当作老师模型。此外,还有一个自蒸馏损失:

L distill  = 1 ∑ b = 1 B ∑ i = 1 N D ^ i b , S ∑ b = 1 B ∑ i = 1 N D ^ i b , S ( t i − t i ′ ) 2 \mathcal{L}_{\text {distill }}=\frac{1}{\sum_{b=1}^{B} \sum_{i=1}^{N} \hat{\mathbf{D}}_{i}^{b, S}} \sum_{b=1}^{B} \sum_{i=1}^{N} \hat{\mathbf{D}}_{i}^{b, S}\left(\mathbf{t}_{i}-\mathbf{t}_{i}^{\prime}\right)^{2}Ldistill =b=1Bi=1ND^ib,S1b=1Bi=1ND^ib,S(titi)2

上式中 t i t_itit i ′ t'_iti 分别表示第 i ii 个 token 从 DynamicViT 和老师模型最后一个 block 中输出的结果,

D ^ i b , S \hat{\mathbf{D}}_{i}^{b, S}D^ib,S 表示第 b bb 个样本从第 s ss 个 stage 中输出的决策掩码。此外,还要最小化 DynamicVit 和老师模型的输入结果:

L K L = K L ( y ∥ y ′ ) \mathcal{L}_{\mathrm{KL}}=\mathrm{KL}\left(\mathbf{y} \| \mathbf{y}^{\prime}\right)LKL=KL(yy)

最后,还需要将需要保留的 token 的比例限制在一个预设的值。当给定一组这个比例时,需要最小化以下损失:

L ratio  = 1 B S ∑ b = 1 B ∑ s = 1 S ( ρ ( s ) − 1 N ∑ i = 1 N D ^ i b , s ) 2 \mathcal{L}_{\text {ratio }}=\frac{1}{B S} \sum_{b=1}^{B} \sum_{s=1}^{S}\left(\rho^{(s)}-\frac{1}{N} \sum_{i=1}^{N} \hat{\mathbf{D}}_{i}^{b, s}\right)^{2}Lratio =BS1b=1Bs=1S(ρ(s)N1i=1ND^ib,s)2

因此,总的损失函数便是:

L = L c l s + λ K L L K L + λ distill  L distill  + λ ratio  L ratio  \mathcal{L}=\mathcal{L}_{\mathrm{cls}}+\lambda_{\mathrm{KL}} \mathcal{L}_{\mathrm{KL}}+\lambda_{\text {distill }} \mathcal{L}_{\text {distill }}+\lambda_{\text {ratio }} \mathcal{L}_{\text {ratio }}L=Lcls+λKLLKL+λdistill Ldistill +λratio Lratio 

其中 λ K L = 0.5 , λ distill  = 0.5 , λ ratio  = 2 \lambda_{\mathrm{KL}}=0.5, \lambda_{\text {distill }}=0.5, \lambda_{\text {ratio }}=2λKL=0.5,λdistill =0.5,λratio =2

结果展示

以下是经过每个 stage,并丢弃一定的 token 后,图片的样子。
在这里插入图片描述


版权声明:本文为ttppss原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。