是否可以按降序使用 argsort?

新手上路,请多包涵

考虑以下代码:

 avgDists = np.array([1, 8, 6, 9, 4])
ids = avgDists.argsort()[:n]

这给了我 n 最小元素的索引。是否可以按降序使用同样的 argsort 来获得 n 最高元素的索引?

原文由 shn 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 333
2 个回答

如果你否定一个数组,最低的元素变成最高的元素,反之亦然。因此,索引 n 最高元素是:

 (-avgDists).argsort()[:n]

评论 中所述,对此进行推理的另一种方法是观察大元素在 argsort 中排在 _最后_。因此,您可以从 argsort 的尾部读取以找到 n 最高元素:

 avgDists.argsort()[::-1][:n]

这两种方法的时间复杂度都是 O(n log n) ,因为 argsort 调用是这里的主要术语。但是第二种方法有一个很好的优势:它用 O(1) 切片替换了数组的 O(n) 否定。如果你在循环中使用小数组,那么你可以通过避免这种否定来获得一些性能提升,如果你使用大数组,那么你可以节省内存使用量,因为否定会创建整个数组的副本。

请注意,这些方法并不总是给出相同的结果:如果请求稳定排序实现 argsort ,例如通过传递关键字参数 kind='mergesort' ,那么第一个策略将保持排序稳定性, 但第二种策略会破坏稳定性(即相同项目的位置将被颠倒)。

示例时间:

使用一个由 100 个浮点数和长度为 30 的尾部组成的小数组,视图方法快了大约 15%

 >>> avgDists = np.random.rand(100)
>>> n = 30
>>> timeit (-avgDists).argsort()[:n]
1.93 µs ± 6.68 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
>>> timeit avgDists.argsort()[::-1][:n]
1.64 µs ± 3.39 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
>>> timeit avgDists.argsort()[-n:][::-1]
1.64 µs ± 3.66 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

对于较大的数组,argsort 占主导地位,没有显着的时序差异

>>> avgDists = np.random.rand(1000)
>>> n = 300
>>> timeit (-avgDists).argsort()[:n]
21.9 µs ± 51.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>> timeit avgDists.argsort()[::-1][:n]
21.7 µs ± 33.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>> timeit avgDists.argsort()[-n:][::-1]
21.9 µs ± 37.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

请注意,下面 nedim 的评论 是不正确的。在反转之前或之后截断对效率没有影响,因为这两种操作只是以不同方式跨越数组视图,而不是实际复制数据。

原文由 wim 发布,翻译遵循 CC BY-SA 4.0 许可协议

就像 Python 一样, [::-1] 反转由 argsort()[:n] 返回的数组,给出最后 n 个元素:

 >>> avgDists=np.array([1, 8, 6, 9, 4])
>>> n=3
>>> ids = avgDists.argsort()[::-1][:n]
>>> ids
array([3, 1, 2])

这种方法的优点是 ids 是avgDists的一个 视图

 >>> ids.flags
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False
  WRITEABLE : True
  ALIGNED : True
  UPDATEIFCOPY : False

(’OWNDATA’ 为 False 表示这是一个视图,而不是副本)

另一种方法是这样的:

 (-avgDists).argsort()[:n]

问题在于它的工作方式是为数组中的每个元素创建负数:

 >>> (-avgDists)
array([-1, -8, -6, -9, -4])

并为此创建一个副本:

 >>> (-avgDists_n).flags['OWNDATA']
True

所以如果你每次都用这个非常小的数据集计时:

 >>> import timeit
>>> timeit.timeit('(-avgDists).argsort()[:3]', setup="from __main__ import avgDists")
4.2879798610229045
>>> timeit.timeit('avgDists.argsort()[::-1][:3]', setup="from __main__ import avgDists")
2.8372560259886086

视图方法要快得多(并且使用 12 的内存……)

原文由 dawg 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题
logo
Stack Overflow 翻译
子站问答
访问
宣传栏