AxisError:计算 AUC 时轴 1 超出维度 1 数组的范围

新手上路,请多包涵

我有一个分类问题,我有一个 8x8 图像的像素值和图像代表的数字,我的任务是使用 RandomForestClassifier 基于像素值预测数字(’Number’ 属性)。数值的值可以是0-9。

 from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

forest_model = RandomForestClassifier(n_estimators=100, random_state=42)
forest_model.fit(train_df[input_var], train_df[target])
test_df['forest_pred'] = forest_model.predict_proba(test_df[input_var])[:,1]
roc_auc_score(test_df['Number'], test_df['forest_pred'], average = 'macro', multi_class="ovr")

在这里它抛出一个 AxisError。

追溯(最近一次通话):
  文件“dap_hazi_4.py”,第 44 行,位于
    roc_auc_score(test_df['Number'], test_df['forest_pred'], average = 'macro', multi_class="ovo")
  文件“/home/balint/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py”,第 383 行,在 roc_auc_score
    多类、平均、样本权重)
  文件“/home/balint/.local/lib/python3.6/site-packages/sklearn/metrics/_ranking.py”,第 440 行,在 _multiclass_roc_auc_score
    如果不是 np.allclose(1, y_score.sum(axis=1)):
  文件“/home/balint/.local/lib/python3.6/site-packages/numpy/core/_methods.py”,第 38 行,_sum
    返回 umr_sum(a, axis, dtype, out, keepdims, initial, where)

AxisError:轴 1 超出了维度 1 数组的范围

原文由 Bálint Béres 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 772
2 个回答

该错误是由于您正在按照其他人的建议解决的多类问题。您需要做的就是预测概率,而不是预测类别。我之前遇到过同样的问题,这样做就解决了。

这是如何做到的 -

 # you might be predicting the class this way
pred = clf.predict(X_valid)

# change it to predict the probabilities which solves the AxisError problem.
pred_prob = clf.predict_proba(X_valid)
roc_auc_score(y_valid, pred_prob, multi_class='ovr')
0.8164900342274142

# shape before
pred.shape
(256,)
pred[:5]
array([1, 2, 1, 1, 2])

# shape after
pred_prob.shape
(256, 3)
pred_prob[:5]
array([[0.  , 1.  , 0.  ],
       [0.02, 0.12, 0.86],
       [0.  , 0.97, 0.03],
       [0.  , 0.8 , 0.2 ],
       [0.  , 0.42, 0.58]])

原文由 bhola prasad 发布,翻译遵循 CC BY-SA 4.0 许可协议

实际上,由于您的问题是多类的,因此标签必须是单热编码的。当标签是单热编码时,“multi_class”参数起作用。通过提供单热编码标签,您可以解决错误。

假设,您有 100 个测试标签和 5 个独特的类别,那么您的矩阵大小(测试标签的)必须是 (100,5) NOT (100,1)

原文由 Lalith Bharadwaj Baru 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进