文章目录
什么是知识蒸馏?
- 知识蒸馏是指从大模型(Teacher model)中学习到有用的知识来训练小模型(Student model),在保证不损失太多性能的情况下,进行模型压缩。
- 最早是为了解决模型压缩(轻量化)问题。
- 在蒸馏过程中,student model 学习到 teacher model 的泛化能力,保留了接近 teacher model 的性能。 在保留精度的同时,能压缩模型,提升速度。但只在分类任务上得到了印证,在更复杂的 object detection 上还有待探索。
目标检测中的知识蒸馏
- 目标检测任务 label 信息量更大,根据 label 学到的模型更为复杂,压缩后损失更多;
- 分类任务中,每个类别相对均衡,同等重要。而目标检测任务中,存在类别不平衡问题,背景类偏多;
- 目标检测任务更为复杂,既有类别分类,也有位置回归的预测;
- 现行的知识蒸馏主要针对同一域中数据进行蒸馏,对于跨域目标检测的任务而言,对知识的蒸馏有更高的要求。
[NIPS17] Learning Efficient Object Detection Models with Knowledge Distillation
Introduction:
主要是通过设置三个 loss 函数,分别对 backbone、cls head、reg head 进行蒸馏:
- 对于 backbone: 使用 hint learning进行蒸馏,增加一个 adaptation layers,让 feature map 的维度匹配;
- 对于分类任务:使用 weighted CE Loss 解决类别失衡严重问题;
- 对于回归任务:除了原本的 smooth-ℓ 1 \ell_1ℓ1 loss,增加 teacher bounded regression loss。
Method
教师网络的知识提取分为三点:**中间层 Feature Maps 的 Hint;RPN/RCN 中分类层的 knowledge;以及RPN/RCN 中回归层的 knowlege。**具体如下:
L R C N = 1 N ∑ i L c l s R C N + λ 1 N ∑ j L r e g R C N L R P N = 1 M ∑ i L c l s R P N + λ 1 N ∑ j L r e g R P N L = L R P N + L R C N + γ L H i n t L_{RCN}=\frac{1}{N}\sum_iL_{cls}^{RCN}+\lambda \frac{1}{N}\sum_jL_{reg}^{RCN}\\ L_{RPN}=\frac{1}{M}\sum_iL_{cls}^{RPN}+\lambda \frac{1}{N}\sum_jL_{reg}^{RPN}\\ L=L_{RPN}+L_{RCN}+\gamma L_{Hint}LRCN=N1i∑LclsRCN+λN1j∑LregRCNLRPN=M1i∑LclsRPN+λN1j∑LregRPNL=LRPN+LRCN+γLHint
- N NN 和 M MM 分别是对应部分的batch-size大小,λ \lambdaλ 和 γ \gammaγ 是超参数(这里分别设定为 1 11 和 0.5 0.50.5);
- L c l s L_{cls}Lcls 包括 hard target 和知识蒸馏中的 soft target;
- L r e g L_{reg}Lreg 包括 smooth-ℓ 1 \ell_1ℓ1 和新提出的 teacher bounded ℓ 2 \ell_2ℓ2 regression loss;
- L H i n t L_{Hint}LHint 为主干网络的损失。
分类任务中的类别不均衡现象
教师网络和学生网络的输出分别如下:
P t = softmax ( Z t T ) P s = softmax ( Z s T ) P_t=\text{softmax}(\frac{Z_t}{T})\\ P_s=\text{softmax}(\frac{Z_s}{T})Pt=softmax(TZt)Ps=softmax(TZs)
学生网络的优化损失如下:
L c l s = μ L h a r d ( P s , y ) + ( 1 − μ ) L s o f t ( P s , P t ) L_{cls}=\mu L_{hard}(P_s,~y)+(1-\mu)L_{soft}(P_s,~P_t)Lcls=μLhard(Ps, y)+(1−μ)Lsoft(Ps, Pt)
- L h a r d L_{hard}Lhard 是用 gt 监督的 Cross Entropy
- L s o f t L_{soft}Lsoft 是用教师网络的信息监督的 soft loss。
分类任务中, 分类错误只会来自 foreground categories。目标检测中的分类子任务,background and foreground categories 都会导致错分。
- 对于分类损失中的 background 误分概率占比较高的情况,提出增大蒸馏交叉熵中背景类的权重来解决失衡问题。
L s o f t ( P s , P t ) = − ∑ w c P t log P s L_{soft}(P_s,~P_t)=-\sum w_cP_t\text{log}P_sLsoft(Ps, Pt)=−∑wcPtlogPs
回归任务
对于回归结果的蒸馏,**regression direction 可能和 gt 相差较大:**由于回归的输出是无界的,教师网络的预测方向可能与 gt 的方向相反。因此,将教师的输出损失作为上界,当学生网络的输出损失大于上界时,计入该损失;否则不考虑该 loss。
L b ( R S , R t , y ) = { ∥ R s − y ∥ 2 2 , if ∥ R s − y ∥ 2 2 + m > ∥ R t − y ∥ 2 2 0 , otherwise L r e g = L s m o o t h − ℓ 1 ( R S , y r e g ) + ν L b ( R s , R t , y r e g ) L_b(R_S,~R_t,~y)= \begin{cases} \|R_s-y\|^2_2,~&\text{if}~\|R_s-y\|^2_2+m>\|R_t-y\|^2_2\\ 0,~&\text{otherwise} \end{cases} \\ L_{reg}=L_{smooth-\ell_1}(R_S,~y_{reg})+\nu L_b(R_s,~R_t,~y_{reg})Lb(RS, Rt, y)={∥Rs−y∥22, 0, if ∥Rs−y∥22+m>∥Rt−y∥22otherwiseLreg=Lsmooth−ℓ1(RS, yreg)+νLb(Rs, Rt, yreg)
- m is a margin,权重 ν = 0.5 \nu=0.5ν=0.5;
- y r e g y_{reg}yreg denotes the regression ground truth label,是 proposal 和 gt 之间的回归量;
- R t R_tRt 和 R s R_sRs 分别是 teacher 和 student 网络学出来的回归量;
- L s m o o t h − ℓ 1 L_{smooth-\ell_1}Lsmooth−ℓ1是普通的 smooth ℓ 1 \ell_1ℓ1 回归 loss。
Hint learning with Feature Adaption
论文中证明,using the intermediate representation of the teacher as hint can help the training process and improve the final performance of the student.
L = L R P N + L R C N + γ L H i n t L=L_{RPN}+L_{RCN}+\gamma L_{Hint}L=LRPN+LRCN+γLHint
其中 L H i n t L_{Hint}LHint是学生网络 backbone 的loss:
L H i n t ( V , Z ) = ∥ V − Z ∥ 2 2 L H i n t ( V , Z ) = ∥ V − Z ∥ 1 2 L_{Hint}(V,~Z)=\|V-Z\|^2_2\\ L_{Hint}(V,~Z)=\|V-Z\|^2_1LHint(V, Z)=∥V−Z∥22LHint(V, Z)=∥V−Z∥12
变量 V , Z V,~ZV, Z 分别是教师网络和学生网络的 feature map(全 feature imitation),需要加入 adaption layer 使得二者维度相同。