我正在为多个标记数据绘制混淆矩阵,其中标签如下所示:
标签 1:1、0、0、0
标签 2:0、1、0、0
标签 3:0、0、1、0
标签 4:0、0、0、1
我能够使用以下代码成功分类。 我只需要一些帮助来绘制混淆矩阵。
for i in range(4):
y_train= y[:,i]
print('Train subject %d, class %s' % (subject, cols[i]))
lr.fit(X_train[::sample,:],y_train[::sample])
pred[:,i] = lr.predict_proba(X_test)[:,1]
我使用以下代码打印混淆矩阵,但它总是返回一个 2X2 矩阵
prediction = lr.predict(X_train)
print(confusion_matrix(y_train, prediction))
原文由 tourist 发布,翻译遵循 CC BY-SA 4.0 许可协议
我找到了一个可以绘制从
sklearn
生成的混淆矩阵的函数。它看起来像这样