最近,有粉丝问我,attention结构中计算qkv的时候,为什么要做kvcache呢?他看了一些文章,没看懂。
为什么要做kvcache?
假设模型的输入序列长度是2,隐藏层的维度是H,那么q、k、v的维度分别是[2, H]
假设它们的值分别是:
q=[q1,
q2]
k=[k1,
k2]
v=[v1,
v2]
那么首先q*k的结果为:
[q1*k1, q1*k2
q2*k1, q2*k2]
然后需要做一个mask,只留下下三角的值,其他值都取0,得到:
[q1*k1, 0
q2*k1, q2*k2]
为什么要做mask,我认为是要和训练时的规则保持一致,因为训练的时候,是认为每个token只能看到它前面的词的。
然后计算qk*v:
[q1*k1*v1,
q2*k1*v1+q2*k2*v2]
完成后续的计算可以预测得到1个新的token。如果还需要继续预测下一个词,在下一次计算的时候我们假设q、k、v为:
q=[q1,
q2,
q3]
k=[k1,
k2,
k3]
v=[v1,
v2,
v3]
同样得到q*k为:
[q1*k1, 0, 0
q2*k1, q2*k2, 0,
q3*k1, q3*k2, q3*k3]
qk*v为:
[q1*k1*v1,
q2*k1*v1+q2*k2*v2,
q3*k1*v1+q3*k2*v2+q3*k3*v3]
可以看到,第2次计算得到的qk相比于第1次的qk只是多了第3行。而第3行的值是q3*[k1, k2, k3],所以为了避免重复计算,我们只需要在第2次计算的时候,只计算新token对应的q3和k3,然后把k3和第1次计算得到的[k1, k2]拼接起来即可,[k1, k2]就是 k cache。
同样可以发现,第2次计算得到的qkv相比于第1次的qkv只是多了第3行。而第3行的值是qk*[v1, v2, v3],所以为了避免重复计算,我们只需要在第2次计算的时候,只计算新token对应的v3,然后把v3和第1次计算得到的[v1, v2]拼接起来即可,[v1, v2]就是 v cache。
以此类推,在后续的增量推理过程中,每次只需要计算新token的q、k、v,然后利用之前缓存的kv cache计算qk和qkv。
transformer是怎样预测出下一个词的?
首先,从数学层面来讲,是这样计算的:
首先,假设输入序列的长度是L,隐藏层的特征维度是H,词汇表的长度是V,那么在计算qkv的过程中,输入x的shape变化如下:
q*k:(L, H)x(H, L)->(L, L)
qk*v:(L, L)x(L, H)->(L, H)
然后再经过forward layer的一系列全连接层,得到的输出shape为(L, V),而它的最后一个分量,也就是output[L-1],就是预测结果的概率分布。
那么怎么理解这个计算过程呢?这个就可以有很多答案了,我一般是这么给别人解释的:首先在计算q*k的时候,qk的最后一个分量是用最后一个词去和其他词的key值做乘法,这一步相当于计算最后一个词和句子中每个词的相关性,然后乘以v就相当于把最后一个词和其他词的相关性进行一个组合,后面再通过多个全连接层进行上下文理解这个词在整个句子中的含义,并预测出下一个词。
这里又引入了另一个问题,既然在首次计算时,只用到了最后一个分量,为什么还要计算qk和qkv的第1到第L-1个分量的值呢?这是因为大模型由多个decoder layer叠加组成。第1个decoder输出的结果还需要作为x输入给第2个decoder layer,进行多轮"思考"。再具体一点,我们还是假设输入序列长度是2,经过第1个decoder layer后输出为:
[h1,
h2]
那么它再作为输入传给第2个decoder layer,第2个decoder layer计算得到的qkv是:
[q1*k1*v1,
q2*k1*v1+q2*k2*v2]
它的最后一个分量是q2k1v1+q2k2v2,其中的k1、v1都和h1相关,所以做首次计算(也就是我们常说的全量计算)时,qk和qkv的每个分量都要计算。
大家还有什么疑问呢?欢迎讨论哦!
本文由博客一文多发平台 OpenWrite 发布!
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。