精通知识蒸馏用于 LLMS:5 个关键技术和实际应用
知识蒸馏通过创建更小、更快、更易于部署的模型,释放了 LLM 在实际应用中的潜力。本文提供了知识蒸馏的全面指南,涵盖了视觉、NLP 和语音领域的算法、架构和应用
大规模机器学习和深度学习模型变得越来越普遍。例如,据报道 GPT-4o 拥有超过 2000 亿个参数。然而,虽然训练大型模型有助于提高最先进的性能,但部署这种笨重的模型,尤其是在边缘设备上,并非易事。
此外,大多数数据科学建模工作都侧重于训练单个大型模型或不同模型的集合,以便在保留的验证集上表现良好,而该验证集通常不能代表真实世界的数据。
这种训练和测试目标之间的不一致导致了机器学习模型的开发,这些模型在精心设计的验证数据集上产生了良好的准确性,但在实际测试数据上的推理时,往往无法满足性能、延迟和吞吐量基准。
知识蒸馏通过捕获并将复杂机器学习模型或模型集合中的知识“蒸馏”到一个更小的单个模型中来帮助克服这些挑战,该模型更容易部署,而不会显着降低性能。
在这篇博文中,我将:
- 详细描述知识蒸馏,其潜在的原理、训练方案和算法;
- 深入研究深度学习中知识蒸馏在图像、文本和音频方面的应用。
什么是知识蒸馏?
知识蒸馏是指将知识从一个大型笨重的模型或一组模型转移到一个可以在实际约束下实际部署的单个较小模型的过程中。从本质上讲,它是一种模型压缩形式,最早由 Bucilua 及其合作者在 2006 年成功演示。
知识蒸馏更常在与复杂架构相关的神经网络模型上执行,这些架构包括多个层和模型参数。因此,随着过去十年深度学习的出现及其在语音识别、图像识别和自然语言处理等不同领域的成功,知识蒸馏技术在实际应用中获得了突出地位。
部署大型深度神经网络模型的挑战对于内存和计算能力有限的边缘设备尤其重要。为了应对这一挑战,首先提出了一种模型压缩方法,将知识从一个大型模型转移到训练一个较小的模型中,而不会显着降低性能。这种从较大的模型中学习较小模型的过程被 Hinton 及其同事正式化为“知识蒸馏”框架。
如图 1 所示,在知识蒸馏中,一个小的“学生”模型学习模仿一个大的“教师”模型,并利用教师的知识来获得相似或更高的准确性。在下一节中,我将更深入地研究知识蒸馏框架及其底层架构和机制。
深入研究知识蒸馏
知识蒸馏系统由三个主要部分组成:知识、蒸馏算法和师生架构。
知识
在神经网络中,知识通常指学习到的权重和偏差。同时,大型深度神经网络中的知识来源也多种多样。典型的知识蒸馏使用 logits 作为教师知识的来源,而其他则侧重于中间层的权重或激活。其他类型的相关知识包括不同类型的激活和神经元之间的关系,或者教师模型本身的参数。
不同形式的知识被分为三种不同的类型:基于响应的知识、基于特征的知识和基于关系的知识。图 2 说明了来自教师模型的这三种不同类型的知识。我将在下一节详细讨论这些不同的知识来源。
1. 基于响应的知识
如图 2 所示,基于响应的知识侧重于教师模型的最终输出层。假设是学生模型将学习模仿教师模型的预测。如图 3 所示,这可以通过使用一个损失函数来实现,称为蒸馏损失,该损失函数捕获学生模型和教师模型 logits 之间的差异。随着训练中此损失的最小化,学生模型将变得更擅长做出与教师模型相同的预测。
在图像分类等计算机视觉任务的背景下,软目标包含基于响应的知识。软目标表示输出类别的概率分布,通常使用 softmax 函数进行估计。每个软目标对知识的贡献都使用一个称为温度的参数进行调制。基于软目标的基于响应的知识蒸馏通常用于监督学习的背景下。
2. 基于特征的知识
经过训练的教师模型还在其中间层中捕获数据的知识,这对于深度神经网络尤其重要。中间层学习区分特定特征,并且此知识可用于训练学生模型。如图 4 所示,目标是训练学生模型学习与教师模型相同的特征激活。蒸馏损失函数通过最小化教师模型和学生模型的特征激活之间的差异来实现这一点。
3. 基于关系的知识
除了神经网络的输出层和中间层中表示的知识之外,捕获特征图之间关系的知识也可用于训练学生模型。这种形式的知识,称为基于关系的知识,如图 5 所示。这种关系可以建模为特征图之间的相关性、图、相似度矩阵、特征嵌入或基于特征表示的概率分布。
训练
有三种主要的训练学生模型和教师模型的方法,即离线蒸馏、在线蒸馏和自蒸馏。蒸馏训练方法的分类取决于教师模型是否与学生模型同时修改,如图所示
1. 离线蒸馏
离线蒸馏是最常见的方法,其中使用预训练的教师模型来指导学生模型。在该方案中,首先在训练数据集上预训练教师模型,然后从教师模型中提取知识来训练学生模型。鉴于深度学习的最新进展,可以公开获得各种预训练的神经网络模型,这些模型可以根据用例充当教师。离线蒸馏是深度学习中一种成熟的技术,并且易于实现。
2. 在线蒸馏
在离线蒸馏中,预训练的教师模型通常是一个大容量的深度神经网络。对于某些用例,可能无法进行离线蒸馏的预训练模型。为了解决此限制,可以使用在线蒸馏,其中教师模型和学生模型都在单个端到端训练过程中同时更新。在线蒸馏可以使用并行计算进行操作,从而使其成为一种高效的方法。
3. 自蒸馏
如图 6 所示,在自蒸馏中,同一模型用于教师模型和学生模型。例如,深度神经网络较深层的知识可用于训练较浅的层。它可以被认为是在线蒸馏的一种特殊情况,并以多种方式实例化。教师模型较早时期的知识可以转移到其后期时期以训练学生模型。
架构
学生-教师网络架构的设计对于高效的知识获取和蒸馏至关重要。通常,更复杂的教师模型和更简单的学生模型之间存在模型容量差距。可以通过优化知识转移(通过高效的学生-教师架构)来缩小这种结构差距。
由于深度神经网络的深度和广度,从深度神经网络转移知识并不简单。用于知识转移的最常见架构包括一个学生模型,该模型是:
- 教师模型的一个较浅版本,具有较少的层和每层较少的神经元,
- 教师模型的量化版本,
- 具有高效基本运算的较小网络,
- 具有优化的全局网络架构的较小网络,
- 与教师相同的模型。
除了上述方法外,最近的进展(如神经架构搜索)也可用于设计给定特定教师模型的最佳学生模型架构。
知识蒸馏算法
在本节中,我将重点介绍用于训练学生模型以从教师模型中获取知识的算法。
1. 对抗蒸馏
对抗学习最近在生成对抗网络(GAN)的背景下被概念化,用于训练生成器模型,该模型学习生成与真实数据分布尽可能接近的合成数据样本,以及鉴别器模型,该模型学习区分真实数据样本和合成数据样本。此概念已应用于知识蒸馏,以使学生模型和教师模型能够更好地表示真实的数据分布。
为了满足学习真实数据分布的目标,对抗学习可用于训练生成器模型以获得合成训练数据,以用作此类数据或增强原始训练数据集。第二种基于对抗学习的蒸馏方法侧重于鉴别器模型,以根据 logits 或特征图区分来自学生模型和教师模型的样本。此方法有助于学生很好地模仿教师。第三种基于对抗学习的蒸馏技术侧重于在线蒸馏,其中学生模型和教师模型被联合优化。
2. 多教师蒸馏
在多教师蒸馏中,学生模型从几个不同的教师模型中获取知识,如图 7 所示。使用教师模型的集合可以为学生模型提供不同类型的知识,这些知识可能比从单个教师模型获得的知识更有益。
来自多个教师的知识可以组合为所有模型的平均响应。通常从教师那里转移的知识类型基于 logits 和特征表示。
3. 跨模态蒸馏
图 8 显示了跨模态蒸馏训练方案。在这里,教师在一个模态中进行训练,其知识被蒸馏到需要来自不同模态的知识的学生中。当在训练或测试期间无法获得特定模态的数据或标签时,就会出现这种情况,因此需要跨模态转移知识。
跨模态蒸馏最常用于视觉领域。例如,在标记图像数据上训练的教师的知识可用于蒸馏具有未标记输入域(如光流或文本或音频)的学生模型。在这种情况下,来自教师模型的图像中学习的特征用于学生模型的监督训练。跨模态蒸馏对于视觉问答、图像字幕等应用很有用。
4. 其他
除了上述讨论的蒸馏算法之外,还有几种其他算法已被应用于知识蒸馏。
- 基于图的蒸馏 使用图来捕获数据内的关系,而不是从教师到学生的单个实例知识。图以两种方式使用——作为知识转移的手段,并控制教师知识的转移。在基于图的蒸馏中,图的每个顶点代表一个自监督的教师,该教师可能基于响应或基于特征的知识,如 logits 和特征图。
- 基于注意力的蒸馏 基于使用注意力图从特征嵌入中转移知识。
- 无数据蒸馏 基于合成数据,由于隐私、安全或机密性的原因,没有训练数据集。合成数据通常由预训练的教师模型的特征表示生成。在其他应用中,GAN 也被用于生成合成训练数据。
- 量化蒸馏 用于将知识从高精度教师模型(例如 32 位浮点数)转移到低精度学生网络(例如 8 位)。
- 终身蒸馏 基于持续学习、终身学习和元学习的学习机制,其中先前学习的知识被积累并转移到未来的学习中。
- 基于神经架构搜索的蒸馏 用于识别合适的学生模型架构,以优化从教师模型中学习。
知识蒸馏的应用
知识蒸馏已成功应用于多个机器学习和深度学习用例,如图像识别、NLP 和语音识别。在本节中,我将重点介绍现有的应用以及知识蒸馏技术的未来潜力。
1. 视觉
知识蒸馏在计算机视觉领域的应用非常广泛。最先进的计算机视觉模型越来越多地基于深度神经网络,这些神经网络可以从模型压缩中受益以进行部署。知识蒸馏已成功应用于以下用例:
- 图像分类,
- 人脸识别,
- 图像分割,
- 动作识别,
- 目标检测,
- 车道线检测,
- 行人检测,
- 面部地标检测,
- 姿势估计,
- 视频字幕,
- 图像检索,
- 阴影检测,
- 文本到图像的合成,
- 视频分类,
- 视觉问答等。
知识蒸馏也可用于交叉分辨率人脸识别等特定用例,其中基于高分辨率人脸教师模型和低分辨率人脸学生模型的架构可以提高模型性能和延迟。由于知识蒸馏可以利用不同类型的知识,包括跨模态数据、多域、多任务和低分辨率数据,因此可以为特定的视觉识别用例训练各种蒸馏学生模型。
2. NLP
鉴于语言模型或翻译模型等大型深度神经网络的普遍性,知识蒸馏在 NLP 应用中的应用尤其重要。最先进的语言模型包含数十亿个参数,例如,GPT-3 包含 1750 亿个参数。这比以前最先进的语言模型 BERT 大几个数量级,BERT 的基本版本包含 1.1 亿个参数。
因此,知识蒸馏在 NLP 中非常受欢迎,以获得快速、轻量级的模型,这些模型更易于训练且计算成本更低。除了语言建模之外,知识蒸馏还用于 NLP 用例,如:
- 神经机器翻译
- 文本生成
- 问答
- 文档检索
- 文本识别
使用知识蒸馏,可以获得高效且轻量级的 NLP 模型,这些模型可以以较低的内存和计算需求进行部署。学生-教师训练也可用于解决多语言 NLP 问题,其中多语言模型的知识可以相互转移和共享。
案例研究:DistilBERT
DistilBERT 是由 Hugging Face 开发的更小、更快、更便宜和更轻的 BERT 模型。在这里,作者预先训练了一个较小的 BERT 模型,该模型可以在各种 NLP 任务上进行微调,并具有相当强的准确性。在预训练阶段应用了知识蒸馏,以获得 BERT 模型的蒸馏版本,该版本缩小了 40%(6600 万个参数 vs 1.1 亿个参数),速度提高了 60%(在 GLUE 情感分析任务上的推理时间为 410 秒 vs 668 秒),同时保持了与原始 BERT 模型准确度的 97% 相当的模型性能。在 DistilBERT 中,学生与 BERT 具有相同的架构,并且是使用一种新颖的三元组损失获得的,该损失结合了与语言建模、蒸馏和余弦距离损失相关的损失。
3. 语音
最先进的语音识别模型也基于深度神经网络。现代 ASR 模型经过端到端训练,并基于包括卷积层、带有注意力的序列到序列模型以及最近的转换器在内的架构。对于实时、设备上的语音识别,获得更小、更快的模型以获得有效性能变得至关重要。
语音中知识蒸馏有几个用例:
- 语音识别
- 口语识别
- 音频分类
- 说话人识别
- 声学事件检测
- 语音合成
- 语音增强
- 噪声鲁棒 ASR
- 多语言 ASR
- 口音检测
案例研究:亚马逊 Alexa 的声学建模
Parthasarathi 和 Strom (2019) 利用学生-教师训练为 100 万小时的未标记语音数据生成软目标,其中训练数据集仅包含 7000 小时的标记语音。教师模型在所有输出类别上产生概率分布。学生模型也为给定的相同特征向量生成输出类别上的概率分布,并且目标函数优化了这两个分布之间的交叉熵损失。在这里,知识蒸馏有助于简化在大量语音数据上生成目标标签。
结论
现代深度学习应用基于具有大容量、内存占用和慢推理延迟的繁琐神经网络。将此类模型部署到生产环境中是一个巨大的挑战。知识蒸馏是一种优雅的机制,用于训练从大型、复杂的教师模型派生的更小、更轻、更快、更便宜的学生模型。继Hinton及其同事(2015)对知识蒸馏的概念化之后,知识蒸馏方案在获取用于生产用例的高效轻量级模型方面得到了大规模的采用。知识蒸馏是一种复杂的技术,基于不同类型的知识、训练方案、架构和算法。知识蒸馏已经在包括计算机视觉、自然语言处理、语音等在内的不同领域取得了巨大的成功。
参考文献
[1] Distilling the Knowledge in a Neural Network. Hinton G, Vinyals O, Dean J (2015) NIPS Deep Learning and Representation Learning Workshop. https://arxiv.org/abs/1503.02531
[2] Model Compression. Bucilua C, Caruana R, Niculescu-Mizil A (2006) https://dl.acm.org/doi/10.1145/1150402.1150464
[3] Knowledge distillation: a survey. You J, Yu B, Maybank SJ, Tao D (2021) https://arxiv.org/abs/2006.05525
[4] DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (2019) Sanh V, Debut L, Chammond J, Wolf T. https://arxiv.org/abs/1910.01108v4
[5] Lessons from building acoustic models with a million hours of speech (2019) Parthasarathi SHK, Strom N. https://arxiv.org/abs/1904.01624