深度剖析DeepSeek创新注意力机制-MLA

📖阅读时长:19分钟

🕙发布时间:2025-02-17

近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企
公众号【柏企科技说】【柏企阅文

DeepSeek模型在多个基准测试(教育、事实性、数学推理、编程等领域)中都取得了顶尖成绩,甚至能与OpenAI的o1模型一较高下。DeepSeek是如何做到在低资源消耗下还能表现卓越的呢?这就要提到其突破性的架构变革——多头潜在注意力机制(Multi - Head Latent Attention,简称MLA),这一机制是DeepSeek模型成功的关键因素之一。

从“注意力”机制说起

在大语言模型的世界里,“注意力”机制至关重要。简单来说,它能让模型在做预测时,把重点放在输入序列的不同部分。就像我们阅读文章时,会不自觉地关注关键语句一样。注意力机制会衡量序列中每个标记(token)的重要性,不管它们之间的距离有多远,都能捕捉到它们之间的关系,从而帮助模型判断哪些输入标记和当前正在处理的标记最相关。

其实,注意力机制并不是什么新鲜事物。早在循环神经网络(RNN)用于神经机器翻译任务时,它就被广泛应用了,当时的Bahdanau注意力机制就是其中代表。这个机制利用双向RNN作为编码器,处理输入序列$x(1)$到$x(T)$,生成隐藏状态$h(1)$到$h(T)$。注意力机制会计算每个编码器隐藏状态的注意力分数,通过softmax函数把这些分数转化为注意力权重$a(t,1)$到$a(t,T)$($T$是输入令牌的总数) 。然后,根据这些权重对编码器的隐藏状态进行加权求和,得到上下文向量$c(t)$。在每个时间步$t$,解码器会结合当前隐藏状态$s(t)$、上下文向量$c(t)$、前一个隐藏状态$s(t - 1)$和前一个输出$y(t - 1)$,生成下一个输出$y(t)$ 。

后来,在Transformer架构中,又引入了一种特殊的注意力机制——缩放点积注意力(Scaled Dot - Product Attention)。它基于输入令牌嵌入得到的三个值来计算:查询(Query,简称Q),代表模型当前正在处理的令牌向量;键(Key,简称K),代表序列中每个令牌的向量;值(Value,简称V),包含与每个令牌相关联信息的向量。计算公式为:$Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V$,其中$d_k$是键向量的维度。

Transformer架构在使用缩放点积注意力机制时,主要有三种方式:

  1. 自注意力(Self - Attention):用于Transformer的编码器。在这里,查询、键和值都来自同一个输入序列,也就是编码器前一层的输出。
  2. 掩码自注意力(Masked Self - Attention):用于Transformer的解码器。查询、键和值同样来自同一个序列,但会把未来的令牌进行掩码处理。
  3. 交叉注意力(Cross - Attention)或编码器 - 解码器注意力(Encoder - Decoder Attention):也用于Transformer的解码器。此时,查询来自解码器的前一层,而键和值来自编码器的输出 。

迈向多头注意力机制

为了更好地捕捉输入序列的语义关系,Transformer并没有满足于只计算一次注意力分数,而是引入了多头注意力机制(Multi - Head Attention,简称MHA)。它通过多个“头”并行运行注意力机制,每个头都能关注输入序列的不同方面,比如短距离和长距离依赖关系、语法规则等。

假设Transformer架构的整体模型维度为$d(model)$,代表输入/隐藏层表示$X$的维度。在多头注意力机制中,每个头不会直接处理全维度向量,而是使用查询、键和值的低维投影。通过学习得到的投影矩阵$W(Q)(i)$、$W(K)(i)$和$W(V)(i)$,将$X$投影到低维向量中,其中$d(k) = d(v) = d(model)/h$($h$是头的数量)。接下来,每个头会用自己的投影矩阵计算缩放点积注意力分数,最后把这些分数连接起来,再通过学习矩阵$W(O)$进行线性变换。在Transformer架构中,无论是自注意力还是交叉注意力,实际使用的都是多头注意力机制,而不是基本的注意力机制。

不过,多头注意力机制虽然强大,但在推理时存在内存消耗大的问题。当大语言模型生成一个令牌时,需要计算与之前所有令牌的注意力分数。为了加快推理速度,通常会把之前令牌的键和值存储在键值缓存(Key - Value Cache,简称KV缓存)中,查询则根据每个新令牌动态计算。但随着序列长度增加,缓存的键值对数量也会线性增长。对于一个有$L$层、每层$n(h)$个头、每个头维度为$d(h)$的Transformer模型,每个令牌需要缓存$2×n(h)×d(h)×L$个元素 。这样一来,缓存会占用大量GPU内存,尤其在长上下文模型中,会导致缓存检索时GPU内存吃紧,推理速度变慢。

升级到多查询注意力机制(MQA)

为了解决多头注意力机制在推理时的内存问题,多查询注意力机制(Multi - Query Attention,简称MQA)应运而生。

与MHA不同,MQA在所有头之间共享一组键和值,而不是每个头都缓存单独的键值对。这样,KV缓存的大小就从$2×n(h)×d(h)×L$(MHA中的大小)减少到了$2×d(h)×L$ 。因为只需要获取一组键和值,所以GPU内存使用量大大降低,在推理时可以处理更大的批次。像PaLM和Falcon等大语言模型,就采用了MQA来替代MHA。

但MQA也并非完美无缺。由于所有头共享相同的键和值,这使得模型学习到的有效表示减少,大语言模型的表达能力变弱,难以捕捉长距离依赖关系。而且在大语言模型微调时,尤其是处理长输入任务,使用MQA还会导致模型不稳定。

分组查询注意力机制(GQA)来“救场”

有没有一种方法能在减少KV缓存大小的同时,还能保证模型性能呢?谷歌研究团队提出的分组查询注意力机制(Grouped - Query Attention,简称GQA)给出了答案。
GQA是MHA和MQA之间的一种权衡方案。它不像MHA那样每个头都有一个KV对,也不像MQA那样所有头共享一个KV对,而是把多个头分成一组,每组共享一个KV对 。每个组处理自己的查询,但共享相同的键和值。这样,KV缓存大小就变为$2×d(h)×G$($G$是组数) 。这种方式不仅加快了推理速度,还能让大语言模型有效地学习表示。

总结一下这几种注意力机制的特点:MHA生成质量最好,但推理速度最慢;MQA推理速度最快,但生成质量最低;GQA则在两者之间找到了平衡。当$G = 1$时,GQA的工作方式和MQA类似;当$G = n(h)$时,又和MHA相似。研究发现,当$G$取值在4 - 8之间时,GQA性能最佳。像Llama 3系列模型、Mistral 7B和DeepSeek LLM(DeepSeek模型的第一个版本)等流行的大语言模型,都在架构中采用了GQA。而DeepSeek后续版本的模型,更是在此基础上进一步提升了性能。

多头潜在注意力机制(MLA)登场

DeepSeek进一步优化,推出了多头潜在注意力机制(MLA)。MLA旨在进一步缩小KV缓存的大小,同时在性能上超越之前提到的注意力机制(包括MHA)。它通过将KV缓存压缩到低维潜在空间,成功将缓存大小减小了93.3% !下面我们详细看看它是如何做到的。

  1. 低秩键值联合压缩:MLA不会像传统方式那样计算和存储每个令牌的键和值,而是使用下投影矩阵$W(DKV)$把它们压缩成潜在向量$C(KV)$。在推理时,再通过每个头的上投影矩阵$W(UK)$(用于键)和$W(UV)$(用于值)从这个潜在向量中重建KV对。为了降低计算成本,MLA还进行了巧妙的优化:把矩阵$W(UK)$合并到$W(Q)$中,这样就不用显式计算键$K(i)$了;把矩阵$W(UV)$合并到$W(O)$中,也就无需显式计算值$V(i)$了。

  1. 查询的低秩压缩:MLA对查询也进行了类似的压缩。

使用下投影矩阵$W(DQ)$将查询压缩成潜在表示$C(Q)$,需要时再用上投影矩阵$W(UQ)$进行重建。虽然这样做不会减少KV缓存的大小,但能降低训练期间的激活内存使用。(激活内存是训练过程中前向传播时用于存储中间激活的内存,反向传播计算梯度时会用到这些激活。)在使用MHA训练时,每一层都会在内存中显式计算和存储查询,且数量会随着层数线性增加。而在MLA中,只存储查询的压缩表示,减少了反向传播时存储的总激活量。不过要注意,在推理时,每个令牌计算一次查询后就会丢弃,不会存储用于反向传播的激活。所以,查询压缩主要是提高了训练效率,对推理性能没有影响。

研究人员尝试在MLA中使用旋转位置嵌入(RoPE)来加入令牌位置信息,

可这遇到了一些问题。在深入探讨之前,我们先来了解一下位置编码在大语言模型中的工作原理。

Transformer架构并行处理令牌,这虽然让它比RNN在计算上更有优势,但也导致它对令牌顺序不敏感。比如,“The cat sits on the mat.”和“The mat sites on the cat.”这两句话,对Transformer来说没什么区别。但在语言处理中,顺序很重要,所以需要添加位置信息。位置嵌入主要有两种类型:绝对位置嵌入,给每个令牌根据其位置分配唯一编码;相对位置嵌入,编码的是令牌之间的相对距离,而不是绝对位置。这两种嵌入又可以分为固定的(用数学函数计算)和可学习的(模型训练时通过反向传播更新参数)。在原始的Transformer论文中,作者使用的是固定的绝对位置嵌入,通过交替的正弦和余弦函数在偶数和奇数维度上计算位置嵌入$PE$,公式为:$PE(pos,2i)=sin(pos/10000^{2i/d(model)})$,$PE(pos,2i + 1)=cos(pos/10000^{2i/d(model)})$,其中$pos$是令牌索引,$i$是令牌嵌入维度的索引,$d(model)$是总令牌嵌入维度。这些位置嵌入和令牌嵌入维度相同,可以直接相加后再输入Transformer进行处理。

后来,2023年的一项研究提出了旋转位置嵌入(RoPE),这是一种在注意力机制中直接编码绝对和相对位置的新方法。RoPE不会像之前那样添加位置嵌入,而是根据令牌的位置旋转令牌嵌入。具体来说,对于位置$m$处维度为$d$的令牌嵌入$x(m)$,分别使用权重矩阵$W(q)$和$W(k)$将其转换为查询向量$q(m)$和键向量$k(n)$ 。在进行自注意力计算前,使用与位置相关的旋转矩阵$R(m)$对这些向量进行旋转。$R(m)$会独立作用于$q$和$k$中的每对维度。以二维向量为例,旋转矩阵$R(m)$定义为:$\begin{bmatrix}cos(m\theta)& -sin(m\theta)\\sin(m\theta)&cos(m\theta)\end{bmatrix}$,这个矩阵会将向量逆时针旋转,旋转角度与位置$m$成正比,为$m\theta$($d = 2$时,$\theta = 1$ )。对于更高维的向量(假设维度为偶数),会将相邻的维度两两配对,分别进行二维旋转。通过这种方式,RoPE可以让注意力分数编码令牌的相对位置,而且还能体现出相距较远的令牌之间联系的相对重要性低于较近的令牌。

为什么RoPE与MLA不兼容?

回到MLA,我们知道它通过创建键和值的潜在压缩表示$C(KV)$来减少内存使用和提高推理效率。而RoPE需要在计算注意力分数前,根据位置信息用旋转矩阵$R(m)$旋转查询和键。但由于MLA存储的是压缩的键值缓存,不是完整的键,所以如果应用RoPE,每次生成新令牌时都得重新计算所有之前的键,这就破坏了使用压缩KV表示带来的效率提升。另外,之前为了优化,MLA把键向上投影矩阵$W(UK)$合并到了$W(Q)$中,而RoPE的旋转操作会导致矩阵乘法不满足交换律,使得$W(UK)$无法像原来那样与$W(Q)$解耦和合并。

那么RoPE如何在MLA中使用呢?

在MLA中,研究人员引入了一种新方法——解耦旋转位置嵌入(Decoupled Rotary Position Embedding)。首先,计算两种类型的键:一种是之前讨论过的压缩键$K(C)$;另一种是位置敏感或解耦键$K(R)$,它是未压缩的键,用于存储应用RoPE所需的位置信息。查询也会进行类似计算,得到潜在查询$Q(C)$和用于RoPE的位置敏感或解耦查询$Q(R)$。这些计算是在推理时进行的,不会存储。这种方法既保留了低秩KV压缩的优势,又能通过单独存储$K(R)$来应用位置敏感变换,还不会影响RoPE的注意力计算。在这种方式下,每个令牌需要缓存$K(R)$和$C(KV)$,总共缓存$[d(c) + d(h)(R)]×L$个元素($d(c)$是潜在密钥维度,$d(h)(R)$是解耦密钥的每个头维度,$L$是MLA中的层数) ,相比传统Transformer模型,效率大大提高。最后,使用压缩和位置敏感的查询和键来计算注意力分数,得到最终输出。

MLA到底有多厉害?

MLA在KV缓存中存储的元素数量很少,相当于只有2.25个组的GQA,但性能却比MHA更强。


在训练不同的模型时,实验结果很能说明问题:当训练一个总参数为7B的密集模型时,在测试基准上,MHA的表现明显优于GQA和MQA,


这表明GQA和MQA为了减小KV缓存大小牺牲了部分性能。然而,令人惊讶的是,当评估两个总参数分别为16B和250B的混合专家(MoE)模型时,MLA在大多数基准测试中都超过了MHA。而且,MLA实现这种高性能的同时,KV缓存使用量大幅降低,小型MoE模型的KV缓存使用率仅为14%,大型MoE模型使用MLA训练时的KV缓存使用率比MHA低4% 。这充分说明,MLA不仅没有因为减小KV缓存而牺牲性能,反而同时提升了性能和效率。MLA也是DeepSeek - V2在MMLU性能提升、训练成本降低、KV缓存变小和生成吞吐量提高等方面的重要原因之一。与DeepSeek 67B(DeepSeek的第一个大语言模型,是使用分组查询注意力(GQA)和RoPE嵌入训练的密集模型)相比,DeepSeek - V2的进步十分显著。同时,MLA也是DeepSeek - V3在各种语言、编程和数学基准测试中表现优异的关键因素。

参考

DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model

DeepSeek-V3 Technical Report
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Fast Transformer Decoding: One Write-Head is All You Need

RoFormer: Enhanced Transformer with Rotary Position Embedding

## 推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!

本文由mdnice多平台发布


柏企科技圈
15 声望4 粉丝