知识蒸馏:TinyBERT如何为自然语言理解提炼BERT知识


📖阅读时长:19分钟

🕙发布时间:2025-02-13

近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企
公众号【柏企科技说】【柏企阅文

知识蒸馏旨在将大型教师网络T的知识转移到小型学生网络S中。用$f_T$和$f_S$分别表示教师网络和学生网络的行为函数。

在Transformer蒸馏的情境下,多头注意力机制(MHA)层、前馈神经网络(FFN)层的输出,或者一些中间表示(比如注意力矩阵A)都可以作为行为函数。形式上,知识蒸馏(KD)可以建模为最小化以下目标函数:

假设学生模型有(M)个Transformer层,教师模型有(N)个Transformer层。我们首先从教师模型的(N)个层中选择(M)个层用于Transformer层蒸馏。然后定义一个函数(n = g(m)),它是学生层索引和教师层索引之间的映射函数。

因此,学生网络可以通过最小化以下目标来从教师网络获取知识:

Transformer层蒸馏

其中,(h)是注意力头的数量,(A_i)指第(i)个注意力头对应的注意力矩阵。

其中,矩阵(H^S)和(H^T)分别指学生网络和教师网络的隐藏状态。矩阵(W_h)是一个可学习的线性变换,它将学生网络的隐藏状态转换到与教师网络状态相同的空间。

嵌入层蒸馏

其中,矩阵(E^S)和(E^T)分别指学生网络和教师网络的嵌入。矩阵(W_e)是一个线性变换,其作用与(W_h)类似。

预测层蒸馏

其中,(z^T)和(z^S)分别是教师网络和学生网络预测的对数几率(logits)向量,(t)表示温度值。在实验中发现,(t = 1)时效果良好。

统一蒸馏损失

利用上述蒸馏目标,我们可以统一教师网络和学生网络相应层之间的蒸馏损失:

TinyBERT学习

TinyBERT提出了一种新颖的两阶段学习框架,包括通用蒸馏和任务特定蒸馏。

通用蒸馏帮助TinyBERT学习预训练BERT中蕴含的丰富知识,这对提升TinyBERT的泛化能力至关重要。任务特定蒸馏则进一步让TinyBERT学习微调后的BERT中的知识。

TinyBERT设置

TinyBERT4

  • 学生模型:TinyBERT4((M = 4),(d = 312),(d’ = 1200),(h = 12))共有1450万个参数
  • 教师模型:BERT BASE((M = 12),(d = 768),(d’ = 3072),(h = 12))共有1.09亿个参数
  • (g(m) = 3m),(\lambda = 1)

TinyBERT6

  • 学生模型:TinyBERT6((M = 6),(d = 768),(d’ = 3072),(h = 12))共有1450万个参数
  • 教师模型:BERT BASE((M = 12),(d = 768),(d’ = 3072),(h = 12))共有1.09亿个参数
  • (g(m) = 2m),(\lambda = 1)

结果

论文

TinyBERT: Distilling BERT for Natural Language Understanding 1909.10351

## 推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!

本文由mdnice多平台发布


柏企科技圈
15 声望5 粉丝