编者按: 近日,Qwen 3 技术报告正式发布,该系列也采用了从大参数模型中蒸馏知识来训练小参数模型的技术路线。那么,模型蒸馏技术究竟是怎么一回事呢?
今天给大家分享的这篇文章深入浅出地介绍了模型蒸馏的核心原理,即通过让学生模型学习教师模型的软标签而非硬标签,从而传递更丰富的知识信息。作者还提供了一个基于 TensorFlow 和 MNIST 数据集的完整实践案例,展示了如何构建教师模型和学生模型,如何定义蒸馏损失函数,以及如何通过知识蒸馏方法训练学生模型。实验结果表明,参数量更少的学生模型能够达到与教师模型相媲美的准确率。
作者 | Wei-Meng Lee
编译 | 岳扬
Photo by 戸山 神奈 on Unsplash
如果你一直在关注 DeepSeek 的最新动态,可能听说过“模型蒸馏”这个概念。但究竟什么是模型蒸馏?它为何重要?本文将解析模型蒸馏原理,并通过一个 TensorFlow 示例进行演示。通过阅读这篇技术指南,我相信您将对模型蒸馏有更深刻的理解。
01 模型蒸馏技术原理
模型蒸馏通过让较小的、较简单的模型(学生模型)学习模仿较大的、较复杂的模型(教师模型)的软标签(而非原始标签),使学生模型能以更精简的架构继承教师模型的知识,用更少参数实现相近性能。以图像分类任务为例,学生模型不仅学习“某张图片是狗还是猫”的硬标签,还会学习教师模型输出的软标签(如80%狗,15%猫,5%狐狸),从而掌握更细粒度的知识。 这一过程能在保持高准确率的同时大大降低模型体积和计算资源需求。
下文我们将以使用 MNIST 数据集训练卷积神经网络(CNN)为例进行演示。
MNIST 数据集(Modified National Institute of Standards and Technology)是机器学习和计算机视觉领域广泛使用的基准数据集,包含 70,000 张 28x28 像素的手写数字(0-9)灰度图像,其中 60,000 张训练图像和 10,000 张测试图像。
首先构建教师模型:
Image by author
教师模型是基于 MNIST 训练的 CNN 网络。
同时构建更轻量的学生模型:
Image by author
模型蒸馏的目标是通过更少的计算量和训练时间训练一个较小的学生模型,复现教师模型的性能表现。
接下来,教师模型和学生模型同时对数据集进行预测,然后计算二者输出的 Kullback-Leibler (KL) 散度(将于后文进行详述)。该数值(KL 散度)用于计算梯度,指导模型各层参数应该如何调整,从而指导学生模型的参数更新:
Image by author
训练完成后,学生模型达到与教师模型相当的准确率:
Image by author
02 创建一个用于模型蒸馏的示例项目
现在,我们对模型蒸馏的工作原理已经有了更清晰的理解,是时候通过一个简单的示例来了解如何实现模型蒸馏了。我将使用 TensorFlow 和 MNIST 数据集训练教师模型,然后应用模型蒸馏技术训练一个较小的学生模型,使其在保持教师模型性能的同时降低资源需求。
2.1 使用 MNIST 数据集
确保已安装 TensorFlow:
下一步加载 MNIST 数据集:
以下是从 MNIST 数据集中选取的前 9 个样本图像及其标签:
需要对图像数据进行归一化处理,并扩展图像数据的维度,为训练做好准备:
2.2 定义教师模型
现在我们来定义教师模型 —— 一个具有多个网络层的 CNN(卷积神经网络):
请注意,学生模型的最后一层有 10 个神经元(对应 10 个数字类别),但未使用 softmax 激活函数。该层直接输出原始 logits 值,这在模型蒸馏过程中非常重要,因为在模型蒸馏阶段会应用 softmax 计算教师模型与学生模型之间的 Kullback-Leibler(KL)散度。
定义完教师神经网络后,需通过 compile() 方法配置优化器(optimizer)、损失函数(loss function)和评估指标(metric for evaluation):
现在可以使用 fit() 方法训练模型:
本次训练进行了 5 个训练周期:
2.3 定义学生模型
在教师模型训练完成后,接下来定义学生模型。与教师模型相比,学生模型的结构更简单、层数更少:
2.4 定义蒸馏损失函数
接下来定义蒸馏损失函数,该函数将利用教师模型的预测结果和学生模型的预测结果计算蒸馏损失(distillation loss)。该函数需完成以下操作:
- 使用教师模型对当前批次的输入数据进行推理,生成软标签「硬标签:[0, 0, 1](直接指定类别3)。软标签:[0.1, 0.2, 0.7](表示模型认为70%概率是类别3,但保留其他可能性)。」;
- 使用学生模型预测计算其软标签;
- 计算教师模型与学生模型软标签之间的 Kullback-Leibler(KL)散度;
- 返回蒸馏损失。
软标签(soft probabilities)指的是包含多种可能结果的概率分布,而非直接分配一个硬标签。例如在垃圾邮件分类模型中,模型不会直接判定邮件"是垃圾邮件(1)"或"非垃圾邮件(0)",而是输出类似"垃圾邮件概率 0.85,非垃圾邮件概率 0.15"的概率分布。 这意味着模型有 85% 的把握认为该邮件是垃圾邮件,但仍认为有 15% 的可能性不是,从而可以更好地进行决策和阈值调整。
软标签使用 softmax 函数进行计算,并由温度参数(temperature)控制分布形态。在知识蒸馏过程中,教师模型提供的软标签能帮助学生模型学习到数据集各类别间的隐含关联,从而获得更优的泛化能力和性能表现。
以下是 distillation_loss() 函数的具体定义:
Kullback-Leibler(KL)散度 (又称相对熵)是衡量两个概率分布差异程度的数学方法。
2.5 使用知识蒸馏方法训练学生模型
现在我们可以通过知识蒸馏训练学生模型了。首先定义 train_step() 函数:
该函数只执行了一个训练步骤:
- 计算学生模型的预测结果
- 利用教师模型的预测结果计算蒸馏损失
- 计算梯度并更新学生模型的权重
要对学生模型进行训练,需要创建一个训练循环(training loop)来遍历数据集,每一步都会更新学生模型的权重,并在每个 epoch 结束时打印损失值以监测训练进度:
2.6 评估学生模型
训练完成后,你可以使用测试集(x_test 和 y_test)评估学生模型的表现:
不出所料,学生模型的准确率相当高:
2.7 使用教师模型和学生模型进行预测
现在可以使用教师模型和学生模型对 MNIST 测试集的数字进行预测,观察两者的预测能力:
前两个样本的预测结果如下:
若测试更多数字图像样本,你会发现学生模型的表现与教师模型同样出色。
03 Summary
在本文,我们探讨了模型蒸馏(Model Distillation)这一概念,这是一种让结构更简单、规模更小的学生模型复现或逼近结构更复杂的教师模型的性能的技术。我们利用 MNIST 数据集训练教师模型,然后应用模型蒸馏技术训练学生模型。最终,层数更少、结构更精简的学生模型成功复现了教师模型的性能表现,同时还大大降低了计算资源的需求。
希望这篇文章能够满足各位读者对模型蒸馏技术的好奇心,也希望本文提供的示例代码可以直观展现该技术的高效与实用。
About the author
Wei-Meng Lee
ACLP Certified Trainer | Blockchain, Smart Contract, Data Analytics, Machine Learning, Deep Learning, and all things tech (http://calendar.learn2develop.net).
END
本期互动内容 🍻
❓除了模型蒸馏,剪枝和量化也是常用的模型压缩方法。在你们的项目中,更倾向于采用哪些方法? 欢迎在评论区分享~
本文经原作者授权,由 Baihai IDP 编译。如需转载译文,请联系获取授权。
原文链接:
https://ai.gopubby.com/understanding-model-distillation-991ec...
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。