混淆矩阵不支持多标签指示器

新手上路,请多包涵

multilabel-indicator is not supported 是我在尝试运行时收到的错误消息:

confusion_matrix(y_test, predictions)

y_test 是一个 DataFrame 其形状为:

 Horse | Dog | Cat
1       0     0
0       1     0
0       1     0
...     ...   ...

predictions 是一个 numpy array

 [[1, 0, 0],
 [0, 1, 0],
 [0, 1, 0]]

我搜索了一些错误消息,但还没有真正找到我可以应用的东西。有什么提示吗?

原文由 Khaine775 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1k
2 个回答

不,您对 confusion_matrix 的输入必须是预测列表,而不是 OHE(一个热编码)。在你的 y_testy_pred argmax --- ,你应该得到你所期望的。

 confusion_matrix(
    y_test.values.argmax(axis=1), predictions.argmax(axis=1))

array([[1, 0],
       [0, 2]])

原文由 cs95 发布,翻译遵循 CC BY-SA 4.0 许可协议

混淆矩阵采用标签向量(不是单热编码)。你应该跑

confusion_matrix(y_test.values.argmax(axis=1), predictions.argmax(axis=1))

原文由 Joshua Howard 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题