Pytorch Tensor如何获取特定值的索引

新手上路,请多包涵

使用 python 列表,我们可以这样做:

 a = [1, 2, 3]
assert a.index(2) == 1

pytorch 张量如何直接找到 .index()

原文由 Han Bing 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 1k
2 个回答

我认为没有从 list.index() 到 pytorch 函数的直接翻译。但是,您可以使用 tensor==numbernonzero() 函数获得类似的结果。例如:

 t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero(as_tuple=True)[0])

这段代码返回

1个

[大小为 1x1 的 torch.LongTensor]

原文由 Manuel Lagunas 发布,翻译遵循 CC BY-SA 4.0 许可协议

对于多维张量,您可以执行以下操作:

 (tensor == target_value).nonzero(as_tuple=True)

生成的张量的形状为 number_of_matches x tensor_dimension 。例如,假设 tensor 是一个 3 x 4 张量(这意味着维度为 2),结果将是一个二维张量,其中包含行中匹配项的索引。

 tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
        [0, 2],
        [1, 2]])

原文由 dopexxx 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题
logo
Stack Overflow 翻译
子站问答
访问
宣传栏