Transformer :数学解释为什么缩放点积会导致更稳定的梯度

最近写了太多热点的模型和方法,感觉飘在天上一样,还是想脚踏实地写一些基础理论。看了一些大佬的文章,发现一些不错的内容,特此总结分享。感谢阅读,水平有限~

📖阅读时长:15分钟

🕙发布时间:2025-02-05

近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企
公众号【柏企科技说】【柏企阅文

Transformer网络中使用的自注意力机制,主要目的是生成能考虑到周围单词上下文信息的词嵌入。自注意力机制通过将句子中的每个单词与其他所有单词进行比较,随后把上下文相关的单词组合在一起,以此完成上述任务。

计算第一个单词的自我注意分数

自注意力机制首先会为句子里的每个单词计算三个向量:查询向量(query)、键向量(key)和值向量(value)。为了找出某个选定单词在上下文中相关的单词,我们会将该单词的查询向量,与句子中其他所有单词的键向量进行点乘,见上图。点乘得到的值范围在负无穷到正无穷之间,所以要使用softmax函数将这些值映射到[0, 1]区间,确保它们在整个序列中的总和为1。对于那些和选定单词不相关的单词,得到的自注意力分数会非常小。

不过,为什么在将点乘结果输入softmax函数之前,要用√64进行缩放呢?在大多数关于Transformer的教程里,我们都听说过点乘结果的数值会变得很大,这会把softmax函数推向梯度极小的区域。

在这篇文章中,我们将从数学角度来理解为什么会这样。为此,我们首先探究softmax函数对于大数值输入的表现,接着分析大数值对softmax函数导数的影响。

Softmax 函数

Transformer网络中softmax函数的主要作用,是将一系列任意实数(包括正数和负数)转化为总和为1的正数:
$$Softmax(x_i)=\frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}}$$
上述公式中的指数函数确保得到的值是非负的。由于分母中的归一化项,这些值的总和为1。

然而,softmax函数并不具有尺度不变性,

输入缩放得越高,最大的输入就越能主导输出结果。随着缩放比例增加,softmax函数会给最大的输入值分配接近1的值,给其他所有值分配接近0的值。这是由指数函数的特性导致的,指数函数的输入越大,增长速度越快。相反,如果缩小输入的尺度,softmax函数的输出就会变得非常相似。

雅可比矩阵

在继续讲解之前,我们得先明确一点:从形式上来说,softmax是一种向量函数,它以向量作为输入,输出也是向量:
$$\text{softmax}: \mathbb{R}^n \to \mathbb{R}^n$$

因此,当我们讨论softmax函数的导数时,实际上说的是它的雅可比矩阵(而不是梯度),雅可比矩阵是由所有一阶偏导数构成的矩阵:
$$J_{ij}=\frac{\partial s_i}{\partial x_j}$$
其中,$s_i = \text{softmax}(x)_i$。

雅可比矩阵元素的闭式表达式:
$$\frac{\partial s_i}{\partial x_j}=s_i(\delta_{ij}-s_j)$$
这里,$\delta_{ij}$是克罗内克函数(Kronecker delta),当$i = j$时,$\delta_{ij}=1$;当$i \neq j$时,$\delta_{ij}=0$。

看看上面这个公式,$s$的偏导数是如何用$s$本身来表示的。为了更直观地了解雅可比矩阵的完整结构,我们以$n = 4$为例写出来看看:

可以看到,对角线元素和非对角线元素是不同的。此外,softmax函数的雅可比矩阵是对称的。

接下来,我们找一种情况,让雅可比矩阵的所有元素都变为零。很容易发现,当任意一个$s$取值为0或1时,对角线元素就会变为零。当非对角线元素的其中一个或两个因子为零时,非对角线元素也会变为零。所以,在以下四种情况下,雅可比矩阵会变成零矩阵:

对于大的输入,softmax函数生成的输出和上述情况非常相似。

通过Softmax层反向传播

最后,我们来解释为什么softmax函数的大输入值会导致在反向传播过程中梯度消失。

假设,通过反向传播,我们已经计算出了softmax函数输出端的梯度,如下图所示:

接下来,我们想通过softmax函数进行反向传播,得到输入端的梯度。要知道,softmax函数的每个输出都依赖于它的所有输入。根据链式法则,对于第$j$个输入,我们可以得到:
$$\frac{\partial L}{\partial x_j}=\sum_{i = 1}^{n}\frac{\partial L}{\partial s_i}\frac{\partial s_i}{\partial x_j}$$

等式右边的行向量可以看作是雅可比矩阵的第$j$列。所以,通过softmax层进行反向传播,就相当于乘以它的雅可比矩阵:
$$\nabla_x L = J^T \nabla_s L$$

我们已经知道,当softmax函数的输入数值变得很大时,它的雅可比矩阵会趋近于零矩阵。在这种情况下,梯度流(误差传播)就会被softmax层阻断,softmax层之前的所有元素的学习速度都会变慢,甚至完全停止学习。

在Transformer网络中,softmax函数的输入是由键向量和查询向量之间的点积构成的。键向量和查询向量的维度$d$越大,点积的值往往也会越大。在原始论文里,键向量的维度是64。论文作者采用的解决办法是,将点积结果除以查询向量和键向量维度的平方根。这样一来,无论键向量和查询向量的维度是多少,学习过程都能顺利进行。

引用

本文由mdnice多平台发布


柏企科技圈
1 声望0 粉丝

时间差不多了,快上车!~