Numpy:将每行中的最大值更改为 1,将所有其他数字更改为 0

新手上路,请多包涵

我正在尝试实现一个 numpy 函数,该函数将 2D 数组每一行中的最大值替换为 1,将所有其他数字替换为零:

 >>> a = np.array([[0, 1],
...               [2, 3],
...               [4, 5],
...               [6, 7],
...               [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [1. 0.]]

到目前为止我试过的

def some_function(x):
    a = np.zeros(x.shape)
    a[:,np.argmax(x, axis=1)] = 1
    return a

>>> b = some_function(a)
>>> b
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]

原文由 MikeRand 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1.9k
2 个回答

方法#1,调整你的:

 >>> a = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [9, 8]])
>>> b = np.zeros_like(a)
>>> b[np.arange(len(a)), a.argmax(1)] = 1
>>> b
array([[0, 1],
       [0, 1],
       [0, 1],
       [0, 1],
       [1, 0]])

[实际上, range 可以正常工作;我写了 arange 出于习惯。]

方法#2,使用 max 代替 argmax 来处理多个元素达到最大值的情况:

 >>> a = np.array([[0, 1], [2, 2], [4, 3]])
>>> (a == a.max(axis=1)[:,None]).astype(int)
array([[0, 1],
       [1, 1],
       [1, 0]])

原文由 DSM 发布,翻译遵循 CC BY-SA 3.0 许可协议

我更喜欢像这样使用 numpy.where :

 a[np.where(a==np.max(a))] = 1

原文由 Cyclone 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题