DeepSeek-R1背后关键——多头潜在注意力机制(MLA),现在也能轻松移植到其他模型了!
而且只需原始数据的0.3%~0.6%。
这项研究由复旦大学、华东师范大学、上海AI Lab等联合提出,复旦教授邱锡鹏(Moss大模型项目负责人)也在作者名单之列。
他们提出了MHA2MLA这种数据高效的微调方法,使基于MHA(多头注意力)的大语言模型(LLMs)能够顺利转换到MLA架构。
以Llama2-7B为例,MHA2MLA在降低推理成本(如减少KV缓存大小92.19%)的同时,能将性能损失控制在较小范围(如LongBench性能仅下降0.5%)。
具体咋回事,下面我们接着看。
掌握DeepSeek核心秘诀
多头注意力MHA(Multi-Head Attention)是Transformer架构中的一个核心组件,允许模型同时关注输入的不同部分,每个注意力头都独立地学习输入序列中的不同特征。
然而,随着序列长度的增长,键值(Key-Value,KV)缓存的大小也会线性增加,这给模型带来了显著的内存负担。
为了解决MHA在高计算成本和KV缓存方面的局限性,DeepSeek突破性地引入了多头潜在注意力机制MLA。
简单说,MLA最大创新之处在于:
利用低秩联合压缩键值技术,减少了推理时的KV缓存,从而在保持性能的同时显著降低内存占用。
这一技术也被视为DeepSeek-V3、DeepSeek-R1等当红炸子鸡模型背后的关键。
而现在,为了进一步降低其他LLMs的推理成本,研究人员开发了一种能将采用MHA的模型快速适配MLA架构的方法——MHA2MLA。
这一数据微调方法包含两个关键部分:
- partial-RoPE,即从对注意力分数贡献较小的查询和键的维度中移除旋转位置嵌入(RoPE);
- 低秩近似,基于预训练的键和值参数引入联合奇异值分解(SVD)近似。
先说第一个。Transformer架构中,RoPE(旋转位置编码,Rotary Position Embedding) 通过旋转操作将位置信息融入查询向量Q和键向量K ,帮助模型捕捉序列位置关系。
但研究发现,在计算注意力分数时,并非所有维度的RoPE对结果贡献相同。
换句话说,即使去除那些对注意力分数影响较小的部分维度的RoPE,理论上不会对模型理解上下文的能力造成关键影响。
基于此,研究人员通过计算敏感度指标来确定哪些维度的RoPE贡献较小。
具体而言,对于每个维度,计算RoPE变化时注意力分数的变化程度。一旦变化程度低于特定阈值的维度,即被判定为对注意力分数贡献小。在后续计算中,这些维度将不再应用RoPE。
最终实验证明,partial-RoPE这一策略在不显著影响模型性能的前提下,减少了计算量。
再说低秩近似策略。
该方法基于预训练的键和值参数,引入联合奇异值分解(SVD)近似。
SVD是一种矩阵分解技术,通过对键值矩阵进行SVD分解,可以用低秩矩阵近似原始矩阵,从而减少参数数量。
具体实现中,研究人员首先提取预训练模型中的键和值参数矩阵,对这些矩阵进行联合SVD分解;然后根据模型的性能和压缩需求,构建低秩近似矩阵,用这些低秩近似矩阵替代原始的键值矩阵参与后续计算。
最终结果显示,此举有效降低了模型推理时的计算量和内存占用。
性能几乎不变,将Llama2 KV缓存减少90%以上
实验环节也验证了MHA2MLA方法的有效性。
能在显著降低推理成本的同时,保持甚至提升模型性能。
研究人员选取了用MHA或GQA预先训练的不同规模(135M-7B)的LLMs,然后设置了对照组。
一组是基于传统MHA的原始模型,用于直接对比MHA2MLA方法在相同任务和数据集上的性能表现;另一组是采用分组查询注意力(GQA)的模型,GQA作为MHA的变体,在一定程度上优化了计算成本,将其与MHA2MLA对比,能更清晰地展现MHA2MLA的优势。
在评估其常识性推理能力的六个基准测试中,研究发现:
与原始LLMs性能相比,四个基础模型的性能变化极小,135M模型性能下降0.25%,360M、1B7和7B模型分别有0.03% 、0.03%和0.37%的性能提升或保持。
这表明微调数据未显著影响原模型性能,MHA2MLA能有效实现架构迁移,而且微调数据仅需预训练数据的0.3%-0.6%。
甚至,较大模型在转换到MLA架构时性能下降更少,这说明这一方法对规模更大的模型更有效。
此外,在长文本生成能力评估中,以LongBench为基准,MHA2MLA相比训练后量化方法,在压缩率和精度平衡上表现出色。
当dkv=16时,MHA2MLA可实现87.5%的压缩率,精度损失仅3%;与4-bit量化结合后,压缩率可达92.19%(dkv=64 + Int4HQQ)和96.87%(dkv=16 + Int4HQQ),精度损失分别为-0.5%和-3.2%,优于所有2-bit量化的基线模型。
这也反映了MHA2MLA方法能够与量化技术良好兼容。
综合以上实验,可以看到以Llama2-7B为例,MHA2MLA在降低推理成本(如减少KV缓存大小92.19%)的同时,能将性能损失控制在较小范围(如LongBench性能仅下降0.5%)。
不过,论文也提到了研究局限性。
受计算资源限制,未在更大、更多样化的开源大语言模型上验证MHA2MLA;且由于Deepseek未开源MLA的张量并行推理框架,难以探索大于7B的模型。
下一步,研究人员计划在更多模型上进行验证。
感兴趣的童鞋可以查看原论文~
论文: https://arxiv.org/abs/2502.14837 代码: https://github.com/JT-Ushio/M...
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。