带有标签的sklearn plot混淆矩阵

新手上路,请多包涵

我想绘制一个混淆矩阵来可视化分类器的性能,但它只显示标签的数量,而不是标签本身:

 from sklearn.metrics import confusion_matrix
import pylab as pl
y_test=['business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business', 'business']

pred=array(['health', 'business', 'business', 'business', 'business',
       'business', 'health', 'health', 'business', 'business', 'business',
       'business', 'business', 'business', 'business', 'business',
       'health', 'health', 'business', 'health'],
      dtype='|S8')

cm = confusion_matrix(y_test, pred)
pl.matshow(cm)
pl.title('Confusion matrix of the classifier')
pl.colorbar()
pl.show()

如何将标签(健康、商业等)添加到混淆矩阵中?

原文由 hmghaly 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 725
2 个回答

正如 这个问题 所暗示的,您必须通过存储您调用的 matplotlib 函数传递的图形和轴对象来“打开” 较低级别的艺术家 APIfigaxcax 下面的变量)。然后,您可以使用 set_xticklabels / set_yticklabels 替换默认的 x 轴和 y 轴刻度:

 from sklearn.metrics import confusion_matrix

labels = ['business', 'health']
cm = confusion_matrix(y_test, pred, labels)
print(cm)
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(cm)
plt.title('Confusion matrix of the classifier')
fig.colorbar(cax)
ax.set_xticklabels([''] + labels)
ax.set_yticklabels([''] + labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

请注意,我将 labels 列表传递给了 confusion_matrix 函数以确保它正确排序,匹配刻度。

结果如下图:

在此处输入图像描述

原文由 metakermit 发布,翻译遵循 CC BY-SA 3.0 许可协议

更新:

在 scikit-learn 0.22 中,有一个新功能可以直接绘制混淆矩阵(但是,在 1.0 中已弃用,并将在 1.2 中删除)。

请参阅文档: sklearn.metrics.plot_confusion_matrix


旧答案:

我认为值得一提的是 seaborn.heatmap 的使用。

 import seaborn as sns
import matplotlib.pyplot as plt

ax= plt.subplot()
sns.heatmap(cm, annot=True, fmt='g', ax=ax);  #annot=True to annotate cells, ftm='g' to disable scientific notation

# labels, title and ticks
ax.set_xlabel('Predicted labels');ax.set_ylabel('True labels');
ax.set_title('Confusion Matrix');
ax.xaxis.set_ticklabels(['business', 'health']); ax.yaxis.set_ticklabels(['health', 'business']);

在此处输入图像描述

原文由 akilat90 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题