我正在尝试实现一个 numpy 函数,该函数将 2D 数组每一行中的最大值替换为 1,将所有其他数字替换为零:
>>> a = np.array([[0, 1],
... [2, 3],
... [4, 5],
... [6, 7],
... [9, 8]])
>>> b = some_function(a)
>>> b
[[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]
[1. 0.]]
到目前为止我试过的
def some_function(x):
a = np.zeros(x.shape)
a[:,np.argmax(x, axis=1)] = 1
return a
>>> b = some_function(a)
>>> b
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
原文由 MikeRand 发布,翻译遵循 CC BY-SA 4.0 许可协议
方法#1,调整你的:
[实际上,
range
可以正常工作;我写了arange
出于习惯。]方法#2,使用
max
代替argmax
来处理多个元素达到最大值的情况: