我正在使用 scikit-learn 将文本文档(22000)分类为 100 个类。我使用 scikit-learn 的混淆矩阵方法来计算混淆矩阵。
model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm=metrics.confusion_matrix(test_labels,pred)
print(cm)
plt.imshow(cm, cmap='binary')
这就是我的混淆矩阵的样子:
[[3962 325 0 ..., 0 0 0]
[ 250 2765 0 ..., 0 0 0]
[ 2 8 17 ..., 0 0 0]
...,
[ 1 6 0 ..., 5 0 0]
[ 1 1 0 ..., 0 0 0]
[ 9 0 0 ..., 0 0 9]]
但是,我没有收到清晰或清晰的情节。有一个更好的方法吗?
原文由 minks 发布,翻译遵循 CC BY-SA 4.0 许可协议
你可以使用
plt.matshow()
代替plt.imshow()
或者你可以使用seaborn模块的heatmap
( 见文档 矩阵)混淆