使用 python 列表,我们可以这样做:
a = [1, 2, 3]
assert a.index(2) == 1
pytorch 张量如何直接找到 .index()
?
原文由 Han Bing 发布,翻译遵循 CC BY-SA 4.0 许可协议
使用 python 列表,我们可以这样做:
a = [1, 2, 3]
assert a.index(2) == 1
pytorch 张量如何直接找到 .index()
?
原文由 Han Bing 发布,翻译遵循 CC BY-SA 4.0 许可协议
对于多维张量,您可以执行以下操作:
(tensor == target_value).nonzero(as_tuple=True)
生成的张量的形状为 number_of_matches x tensor_dimension
。例如,假设 tensor
是一个 3 x 4
张量(这意味着维度为 2),结果将是一个二维张量,其中包含行中匹配项的索引。
tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
[0, 2],
[1, 2]])
原文由 dopexxx 发布,翻译遵循 CC BY-SA 4.0 许可协议
3 回答3.1k 阅读✓ 已解决
2 回答1.9k 阅读✓ 已解决
2 回答1.3k 阅读✓ 已解决
2 回答1.8k 阅读✓ 已解决
4 回答1.8k 阅读
3 回答1.7k 阅读
1 回答1.4k 阅读✓ 已解决
我认为没有从
list.index()
到 pytorch 函数的直接翻译。但是,您可以使用tensor==number
和nonzero()
函数获得类似的结果。例如:这段代码返回