博客专栏

EEPW首页 > 博客 > 计算机视觉中的知识蒸馏

计算机视觉中的知识蒸馏

发布人:计算机视觉工坊 时间:2022-05-15 来源:工程师 发布文章

作者:ppog@知乎(已授权转载)

编辑:CV技术指南

原文:https://zhuanlan.zhihu.com/p/497067556

前段时间熬完毕设的工作,趁着空闲想写一篇关于知识蒸馏的博客,这是本人读研期间的一个研究方向,但这篇博客不会过于深入,内容大概简短说说自己对于知识蒸馏的一些看法,大多数内容来源于四月份看到的两篇paper。鄙人愚见,有不当之处欢迎批评!

文中涉及到的三篇论文

Distilling the Knowledge in a Neural Network

paper:arxiv.org/pdf/1503.0253

code:github.com/labmlai/anno

Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results

paper:arxiv.org/pdf/2204.0347

code:github.com/Alibaba-MIIL

Decoupled Knowledge Distillation

paper:arxiv.org/abs/2203.0867

code:github.com/megvii-resea

1、知识铺垫
one hot 编码

one-hot 编码(one-hot encoding)类似于虚拟变量(dummy variables),是一种将分类变量转换为几个二进制列的方法,即一种硬编码形式,类似非黑即白。其中 1 代表某个输入样本属于该类别。

图片

soft label

深度学习领域中,通常将数据标注为hard label,但事实上同一个数据包含不同类别的信息,直接标注为hard label无法显示图像数据间的相关性,例如分类任务中,数据样本(下图)的hard label是【sheep:1】,而实际上,样本中包含了一条狗,对应的soft label可能是【sheep:0.90;dog:0.10】。

图片

基于上述事实:

  1. hard label会根据照片,告诉我们这就是羊,其他都不是;

  2. soft label会告诉我们,这张照片大概率是羊,存在一定概率是狗。

但在实际应用中,两者均有其所长:hard label虽然更容易标注,但是会丢失类内、类间的关联。而soft label能给模型带来更强的泛化,携带更多的信息,但是获取难度会比hard label大。


迁移学习和知识蒸馏

总的来说,两者都属于知识迁移的一种,知识蒸馏是模型层面的迁移方式,而迁移学习是数据层面的迁移方式。

具体而言,两个在一定程度下都可以实现涨点,以ImageNet-1K、ImageNet-21K、ResNet18、ResNet31为例(假设验证集恒不变):

  • 对于迁移学习,我们使用ResNet18在ImageNet-21K上进行预训练,训练完后将模型迁移到ImageNet-1K上微调,在验证集不变的情况,精度会更高。

  • 对于知识蒸馏,我们使用ResNet32作为Teacher模型在ImageNet-1K上进行训练,ResNet18作为Student模型同样也在ImageNet-1K上训练,但会引入训练完后的Teacher模型做监督,往往精度也会提高。

但两种方式都会带来一些问题,例如训练周期更长,更大的计算开销,更严重的资源占用等等。


2. 什么是KD?

《Distilling the Knowledge in a Neural Network》是知识蒸馏的开山鼻祖,于2015年提出,目前引用量快超过10k。其提出来的带温度的kl散度损失是最早的分类算法蒸馏方案,由于是基于logits的蒸馏方式,易于复现,后续也有许多在KL散度上进行改进的版本。

图片

KD所需基本对象

图片

Knowledge Distillation 的整体示意如上图所示(基于logits):

  1. Teacher model:结构较为复杂,特征提取能力更强的大模型,如ResNet31

  2. Student model:结构较为简单,特征提取能力一般的小模型,如ResNet18

  3. Hard label:输入数据所对应的类别,上文开头解释过了,常规的训练一般都是使用的Hard label

  4. Soft label:输入数据通过Teacher模型softmax层的输出,蒸馏训练附加的loss基于此得来

  5. distill loss:蒸馏采用的损失可能是KL、MSE、CE等,该论文采用的是基于温度T的KL Loss


KD常见步骤

围绕这几个基本点,共进行步骤如下(假设数据集为cifer):

① ResNet31在cifer数据上训练得到的教师模型

② 将教师模型的prediction软化,即输入数据通过teacher model所得到的softmax层的输出:

图片

③ 得到软化的预测向量后,通过KL散度损失进行下一步计算:

图片

可以看下代码实现:

# y_s: student output logits 
# y_t: teacher output logits 
# T: temperature for KD
# teacher model: resnet31
# student model: resnet18

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        KLDLoss = nn.KLDivLoss(reduction="none")
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.log_softmax(y_t/self.T, dim=1)
        loss = KLDLoss(p_s, p_t) * (self.T**2)
        return loss

④ 在计算出蒸馏的loss后,将这个kl_loss附加在原始的分类损失(假设是CE loss)上:

图片

在经过知识蒸馏的操作后,模型精度得到了提升,但当时开展的相关实验比较少,毕竟是在2014年,各方面条件都有所限制,且文中作者也没有十分详细地解释蒸馏能让模型提升的具体原因。


3. 为什么KD有效?

时间来到了2022年,在4月7日,阿里达摩院在arvix上挂上了《Solving ImageNet》,该论文主要针对目前的计算机视觉模型,提出通用的训练方案USI,并且该方案主要基于KD蒸馏的训练方式。

不过在我看来,该论文展示了许多丰富的实验及结果,并且验证和解释了为何KD是有效的,更像是对14年提出的KD进行详尽的补充。


KD适用的架构

文中提到,目前的计算机视觉模型大致下可以分为四类:

  • 类似ResNet的常规CNN模型(ResNet-like)

  • 面向移动端的轻量模型(Mobile-oriented)

  • Transformer模型(Transformer-base)

  • 仅包含MLP的模型(MLP-only)

该作者对上述四种架构的计算机视觉模型抽样进行了实验,有意思的是,使用基于KD方式的训练方案的模型在Top-1上均获得了不同程度的提高,特别是Mobile-oriented类的轻量模型。

图片

KD的有效性分析

为了更深入地了解KD对模型结果的影响,作者在下图中展示了一些教师模型预测的标签,与ImageNet真实标签的对比。

图片

  1. 图片(a)包含了大量明显的钉子,教师模型的预测是99.9%,而第二和第三个预测也与钉子(螺丝和锤子)相关,但概率值可以忽略不计。

  2. 图片(b)中包含了一架客机,教师模型的最高预测是客机(83.6%)。然而,教师模型也有一些不能忽视的概率(11.3%)。这并非是错误,因为飞机上有机翼。这里的教师模型减轻了实际情况与真实标签相互排斥的情况(即要么是1,要么是0),并提供了关于图像内容更准确的信息(打个比方,前面提到的一张图基本都是羊,但有一条狗,数据集的分类标签是羊,但teacher教师预测时会留出部分概率给了狗)。

  3. 图片(c)中包含了一只母鸡。然而,母鸡的信息并非很明显,教师模型的预测反映了这一点,通过识别出一只概率较低的母鸡(55.5%),还给出了一定的概率给公鸡( 大约8.9%.)。虽然这是教师模型的错误,但实际上就算是人,这么小的目标似乎也很难一下子分得清。

  4. 在图片(d)中,教师模型认为真实标签是错误的。真实标签是冰棍,而教师模型预测概率最大的是狗。作者认为教师模型的预测反而是对的,因为狗在图片中的信息更为突出。

从上面的例子中可以看到,教师模型的预测比简单( 0或1)的真实标签包含了更丰富的信息,soft label解释了类别之间的相关性。不仅如此,KD更能代表增强过后图像的正确信息,能更好处理strong augmentations的问题。由于上述提到的原因,与仅使用hard label的训练相比,使用教师模型的soft label进行训练会提供有更有效的监督,训练会变得更有效、更稳健。


4. 如何让KD更加有效?

上边讲到,KD有作用,但究竟是哪部分起作用,作用多大,是否存在负优化,值得思考!

在今年的3月16日,旷视对KD(KL Loss)进行了更加深入的剖析,提出了解耦蒸馏(《Decoupled Knowledge Distillation》,DKD),这篇文章很精彩,对14年提出的KD(KL Loss)进行了多方位的解析,也开展了许多实验。

图片

如上图所示,研究者将 logits 拆解成两部分,蓝色部分指目标类别(target class)的 score,绿色部分指非目标类别(Non-target class)的 score。并且将KD重新表述为两部分的加权和,即 TCKD 和 NCKD


公式演进

图片

图片

图片

上述定义和数学关系将帮助我们得到 KL Loss 的新表达形式:图片

图片

图片

对于公式的补充解释:

图片


更有说服力的实验

为了观察TCKD 和 NCKD 对蒸馏性能的影响,作者做了大量实验,并试图通过实验剖析TCKD 和 NCKD 的作用。

图片

上图为TCKD 和 NCKD在CIFAR-100 上进行的实验,作者初步得出以下结论:

  • 同时使用 TCKD + NCKD = KD 的蒸馏方式,Student模型均获得不同程度的提升;

  • 单独使用 TCKD 进行蒸馏,会对蒸馏效果产生较大的损害,原因在于高温系数(T)会导致损失附加上很大的梯度,增加非目标类的 logits ,这会损害学生预测的正确性;

  • 单独使用 NCKD 进行蒸馏,和 KD 效果差不多;

基于上述结论,是否 NCKD 更加有效,而 TCKD 存在负优化?作者给出了进一步的探讨。

作者认为 TCKD 受限于数据集的难易程度,假设一个样本经过教师模型后输出概率是0.99,说明这个样本是易样本,数据集是容易分辨的,而当概率只有0.75,甚至是0.55,那么样本会陷入到模棱两可的状态,模型也没有把握认定它就是所谓的那个它(你那么爱它,为什么不把它留下),数据集难度增加。

作者补充了以下三个实验:更重的数据增强;更多的噪声;更复杂的数据。

1、更重的数据增强

图片

上表显示Teacher模型为ResNet32×4,Student模型为ShuffleNet-V1和ResNet8×4的实验结果,在使用 AutoAugment数据增强方法的情况下,训练集难样本系数增大,此时使用 TCKD 可以达到较大的提升。

2、更多的噪声

图片

而通过引入噪声,当噪声比例增大,TCKD 的提升程度也加强。

3、更复杂的数据

图片

使用ResNet34作为Teacher模型,ResNet18作为Student模型,作者发现学生模型的Top-1增加了0.32个点。

最后,作者给出的结论是,通过尝试各种策略来增加训练数据的复杂度(例如重的数据增强、更多的噪声、困难的任务)来证明 TCKD 的有效性。结果证实,在对更具挑战性的训练数据进行知识蒸馏时,训练样本“复杂度(难度)”的提升对于 TCKD 可能更有增益,说明 TCKD 对于数据集中复杂任务的监督能力更强。

而上上上部分,作者也证实了NCKD 能力出众,这也反映了一个事实:说明非目标类之间的知识对logits的蒸馏方式至关重要,它们可以比喻为能力出众的“暗部成员”(知道卡卡西吗?),论文中称之为“暗知识”(dark knowledge)。

如何理解?大家可以把目标类别的logits看作是light knowledge,按照我们惯有的思维,目标类别是最重要的,我想要识别出一条狗,那么我就会找一大堆关于该目标类别的样本,不断填充和丰富它的logits信息,而非目标类别则显得不那么重要,因为我们想要kill的名单中没有他们,但不可置否,dark knowledge对于模型泛化性也非常关键。

图片

依据 Teacher 模型预测的置信度,作者对cifer训练集上的样本做了排序,根据排序结果对数据集进行切分,置信度0.5-1为一块,置信度为0-0.5为一块,实验结果如下:

图片

在前 50% 的样本上使用 NCKD 可以获得更好的性能,这表明预测良好的样本所携带的知识比其他样本更丰富。然而,预测良好的样本的损失权重被教师的高置信度所抑制。这也说明了,置信度高的样本对蒸馏的效果更加显著,应当采取措施让它们不被抑制。


来自七年后的plus版本

图片


5. 实验效果

分类任务

图片

作者使用DKD和KD进行对比,效果都要优于KD(KL Loss)的方式,在不同模型上实现了1-2,甚至是3个点的提升。

图片

并且,作者对一些细节也进行了补充,通常a设置为1时效果较好,而实际应用中变动较大的为Beta,当具体调为何值,需要根据实际的业务数据进行实验。

检测任务

图片

作者使用了Faster rcnn作为baseline,通过替换不同的backbone以此作为teacher和student,可以看出,DKD的方式带来的提升均超过了原始KD的方式,而将DKD与基于Feature蒸馏结合起来组成的DKD+ReviewDKD提升更大。这也证明了,检测任务十分依赖于feature的定位能力,而logits这种high level的信息并不具备这种能力,这也使得基于logits的蒸馏方式效果差于feature的蒸馏,但总的来说,KD的解耦型DKD还是展示了更加优越的性能。


总结

这篇博客从三个层面讲述了KD是什么?为什么有效?突然想写这篇博客,原因在于四月份看到的两篇论文解答了我之前在这个方向上的不少疑惑,随整理出来。但由于本人并未涉略过深,仍会有很多理解不足的地方,也欢迎各位大佬批评指正!

参考文献

[1] pprp:知识蒸馏综述:代码整理

[2] medium.com/analytics-vi

[3] 从标签平滑和知识蒸馏理解Soft Label

[4] [论文阅读]知识蒸馏(Distilling the Knowledge in a Neural Network)

[5] Distilling the Knowledge in a Neural Network 论文笔记

[6] oldsummer:2021 《Knowledge Distillation: A Survey》

[7] CVPR 2022|解耦知识蒸馏!旷视提出DKD:让Hinton在7年前提出的方法重回SOTA行列!

[8] 阿里巴巴提出USI 让AI炼丹自动化了,训练任何Backbone无需超参配置,实现大一统!

本文仅做学术分享,如有侵权,请联系删文。



*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。



关键词: AI

相关推荐

技术专区

关闭