如何绘制混淆矩阵?

新手上路,请多包涵

我正在使用 scikit-learn 将文本文档(22000)分类为 100 个类。我使用 scikit-learn 的混淆矩阵方法来计算混淆矩阵。

 model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm=metrics.confusion_matrix(test_labels,pred)
print(cm)
plt.imshow(cm, cmap='binary')

这就是我的混淆矩阵的样子:

 [[3962  325    0 ...,    0    0    0]
 [ 250 2765    0 ...,    0    0    0]
 [   2    8   17 ...,    0    0    0]
 ...,
 [   1    6    0 ...,    5    0    0]
 [   1    1    0 ...,    0    0    0]
 [   9    0    0 ...,    0    0    9]]

但是,我没有收到清晰或清晰的情节。有一个更好的方法吗?

原文由 minks 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 697
1 个回答

在此处输入图像描述

你可以使用 plt.matshow() 代替 plt.imshow() 或者你可以使用seaborn模块的 heatmap见文档 矩阵)混淆

import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3],
        [3,31,0,0,0,0,0,0,0,0,0],
        [0,4,41,0,0,0,0,0,0,0,1],
        [0,1,0,30,0,6,0,0,0,0,1],
        [0,0,0,0,38,10,0,0,0,0,0],
        [0,0,0,3,1,39,0,0,0,0,4],
        [0,2,2,0,4,1,31,0,0,0,2],
        [0,1,0,0,0,0,0,36,0,2,0],
        [0,0,0,0,0,0,1,5,37,5,1],
        [3,0,0,0,0,0,0,0,0,39,0],
        [0,0,0,0,0,0,0,0,0,0,38]]
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"],
                  columns = [i for i in "ABCDEFGHIJK"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)

原文由 bninopaul 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题