sparse_categorical_crossentropy 和 categorical_crossentropy 有什么区别?

新手上路,请多包涵

sparse_categorical_crossentropycategorical_crossentropy 有什么区别?什么时候应该使用一种损失而不是另一种?例如,这些损失是否适合线性回归?

原文由 xpertdev 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1.1k
2 个回答

简单地:

  • categorical_crossentropy ( cce ) 生成一个单热数组,其中包含每个类别的可能匹配项,
  • sparse_categorical_crossentropy ( scce ) 生成 最可能 匹配类别的类别索引。

考虑具有 5 个类别(或类)的分类问题。

  • cce 的情况下,单热目标可能是 [0, 1, 0, 0, 0] 并且模型可能预测 [.2, .5, .1, .1, .1] (可能是正确的)

  • scce 的情况下,目标索引可能是 [1],模型可能预测:[.5]。

现在考虑一个包含 3 个类别的分类问题。

  • cce 的情况下,单热目标可能是 [0, 0, 1] 并且该模型可以预测 [.5, .1, .4] (可能给出更多概率头等舱)
  • scce 的情况下,目标索引可能是 [0] ,模型可能预测 [.5]

许多分类模型产生 scce 输出,因为你节省了空间,但丢失了很多信息(例如,在第二个例子中,索引 2 也非常接近。)我通常更喜欢 cce 模型可靠性的输出。

有多种情况可以使用 scce ,包括:

  • 当你的课程相互排斥时,即你根本不关心其他足够接近的预测,
  • 类别数量大到预测输出变得不堪重负。

220405 :对“one-hot encoding”评论的回应:

one-hot 编码用于类别特征输入以选择特定类别(例如男性与女性)。这种编码允许模型更有效地训练:训练权重是类别的乘积,对于除给定类别之外的所有类别都为 0。

ccescce 是模型输出。 cce 是每个类别的概率数组,共1.0。 scce 显示最喜欢的类别,总计1.0。

scce 技术上是一个单热阵列,就像用作门挡的锤子仍然是锤子,但其用途不同。 cce 不是一次性的。

原文由 dturvene 发布,翻译遵循 CC BY-SA 4.0 许可协议

我也对这个感到困惑。幸运的是,出色的 keras 文档帮了大忙。两者具有相同的损失函数并且最终做同样的事情,唯一的区别在于真实标签的表示。

  • 分类交叉熵 [ Doc ]:

当有两个或更多标签类别时使用此交叉熵损失函数。我们希望以 one_hot 表示形式提供标签。

 >>> y_true = [[0, 1, 0], [0, 0, 1]]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> cce = tf.keras.losses.CategoricalCrossentropy()
>>> cce(y_true, y_pred).numpy()
1.177

  • 稀疏分类交叉熵 [ Doc ]:

当有两个或更多标签类别时使用此交叉熵损失函数。我们希望标签以整数形式提供。

 >>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> scce = tf.keras.losses.SparseCategoricalCrossentropy()
>>> scce(y_true, y_pred).numpy()
1.177

稀疏分类交叉熵的一个很好的例子是 fasion-mnist 数据集。

 import tensorflow as tf
from tensorflow import keras

fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()

print(y_train_full.shape) # (60000,)
print(y_train_full.dtype) # uint8

y_train_full[:10]
# array([9, 0, 0, 3, 0, 2, 7, 2, 5, 5], dtype=uint8)

原文由 Bitswazsky 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题