Softmax 函数和分类交叉熵损失的导数
最近写了太多热点的模型和方法,感觉飘在天上一样,还是想脚踏实地写一些基础理论。看了一些大佬的文章,发现一些不错的内容,特此总结分享。感谢阅读,水平有限~
📖阅读时长:15分钟
🕙发布时间:2025-02-05
近日热文:全网最全的神经网络数学原理(代码和公式)直观解释
欢迎关注知乎和公众号的专栏内容
LLM架构专栏
知乎LLM专栏
知乎【柏企】
公众号【柏企科技说】【柏企阅文】
在这篇简短的文章中,我们将计算softmax函数的雅可比矩阵。通过应用一个巧妙的计算技巧,我们将使推导过程变得非常简短。然后,利用得到的雅可比矩阵,我们将计算分类交叉熵损失的梯度。
Softmax函数
softmax函数的主要目的是获取一个任意实数向量,并将其转换为概率:
上面公式中的指数函数确保得到的值是非负的。由于分母中的归一化项,得到的值总和为1。此外,所有值都在0到1之间。softmax函数的一个重要特性是,它保留了输入值的排序顺序。
Softmax函数的雅可比矩阵
从形式上讲,softmax函数是一种所谓的向量函数,它以向量作为输入,并生成向量作为输出:
因此,当我们谈论softmax函数的导数时,实际上我们谈论的是它的雅可比矩阵,即所有一阶偏导数组成的矩阵:
注意,softmax函数的每个输出是如何依赖于所有输入值的(这是由于分母的缘故)。因此,雅可比矩阵的非对角元素不为零。
由于softmax函数的输出是严格正值,我们可以通过应用以下技巧,使下面的推导过程非常简短:我们不直接对输出求偏导数,而是对输出的对数求偏导数(也称为“对数导数”):
其中,右侧的表达式直接由链式法则得出。接下来,我们重新整理上面的公式,得到:
左边正是我们要找的偏导数。正如我们很快会看到的,右边简化了导数的计算,这样我们就不需要使用导数的商法则。我们首先要对$s$取对数:
得到的表达式的偏导数为:
让我们看一下右边的第一项:
它可以用指示函数$1\{·\}$简洁地表示。如果指示函数的参数为真,它的值为1,否则为0。
右边的第二项可以通过应用链式法则来计算:
在上面的步骤中,我们使用了自然对数的导数:
求总和的偏导数很简单:
将结果代入公式,得到:
最后,我们必须将上面的表达式与$s$相乘,如本节开头所示:
我们的推导到此结束。我们得到了雅可比矩阵所有元素(对角元素和非对角元素)的公式。对于$n = 4$的特殊情况,我们得到:
看看对角元素和非对角元素有何不同。
分类交叉熵损失
分类交叉熵损失与softmax函数密切相关,因为它实际上只用于输出层带有softmax层的网络。在我们正式介绍分类交叉熵损失(通常也称为softmax损失)之前,我们需要先简要说明两个术语:多类分类和交叉熵。
分类问题可以细分为以下两类:
- 多类分类:每个样本只属于一个类别(互斥)
- 多标签分类:每个样本可能属于多个类别(或不属于任何类别)
分类交叉熵损失专门用于多类分类任务,其中每个样本恰好属于$C$个类别中的一个。因此,分配给每个样本的真实标签由一个介于0到$C - 1$之间的整数值组成。标签可以用一个大小为$C$的独热编码向量来表示,该向量在正确的类别位置上的值为1,其他位置的值为0。例如,当$C = 4$时,如下所示:
交叉熵将两个离散概率分布(简单来说,就是元素在0到1之间且总和为1的向量)作为输入,并输出一个实数值(!),表示这两个概率分布的相似程度:
其中,$C$表示不同类别的数量,下标$i$表示向量的第$i$个元素。交叉熵越小,两个概率分布就越相似。
当交叉熵在多类分类任务中用作损失函数时,$y$被输入独热编码标签,softmax层生成的概率被放入$s$中。这样,我们就不会对零取对数,因为从数学上讲,softmax永远不会真正产生零值。
通过在训练过程中最小化损失,我们基本上是迫使预测概率逐渐接近真实的独热编码向量。
为了启动反向传播过程(如本文所述),我们必须计算损失相对于输出层加权输入$z$的导数,见上图:
让我们代入上一节中得到的导数:
并展开最后一项中的乘积:
指示函数$1\{·\}$在$i = j$时取值为1,在其他地方取值为0:
接下来,我们将$s$从总和中提出,因为它不依赖于索引$i$:
在最后一步中,我们利用了独热编码向量$y$的总和为1这一事实。请记住,独热编码向量可以被解释为一种概率分布,其概率质量集中在单个值上。用简洁的向量表示法,我们得到:
推荐阅读
1. DeepSeek-R1的顿悟时刻是如何出现的? 背后的数学原理
2. 微调 DeepSeek LLM:使用监督微调(SFT)与 Hugging Face 数据
3. 使用 DeepSeek-R1 等推理模型将 RAG 转换为 RAT
4. DeepSeek R1:了解GRPO和多阶段训练
5. 深度探索:DeepSeek-R1 如何从零开始训练
6. DeepSeek 发布 Janus Pro 7B 多模态模型,免费又强大!
本文由mdnice多平台发布
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。