如何将整数的 pytorch 张量转换为布尔值的张量?

新手上路,请多包涵

我想将整数张量转换为布尔张量。

具体来说,我希望能够拥有一个将 tensor([0,10,0,16]) 转换为 tensor([0,1,0,1]) 的函数

这在 Tensorflow 中很简单,只需使用 tf.cast(x,tf.bool)

我希望强制转换将所有大于 0 的整数更改为 1,将所有等于 0 的整数更改为 0。这在大多数语言中等同于 !!

由于 pytorch 似乎没有专用的布尔类型可以转换为,这里最好的方法是什么?

编辑:我正在寻找一种矢量化解决方案,而不是循环遍历每个元素。

原文由 Ross 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 2.4k
2 个回答

您正在寻找的是为给定的整数张量生成一个 布尔掩码。为此,您可以使用简单的比较运算符( > )或使用 torch.gt() 简单地检查条件:“张量中的值是否大于 0”,然后给出我们想要的结果。

 # input tensor
In [76]: t
Out[76]: tensor([ 0, 10,  0, 16])

# generate the needed boolean mask
In [78]: t > 0
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)


 # sanity check
In [93]: mask = t > 0

In [94]: mask.type()
Out[94]: 'torch.ByteTensor'


注意:在 PyTorch 1.4+ 版本中,上述操作将返回 'torch.BoolTensor'

 In [9]: t > 0
Out[9]: tensor([False,  True, False,  True])

# alternatively, use `torch.gt()` API
In [11]: torch.gt(t, 0)
Out[11]: tensor([False,  True, False,  True])

如果您确实想要单个位( 0 s 或 1 s),请使用:

 In [14]: (t > 0).type(torch.uint8)
Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# alternatively, use `torch.gt()` API
In [15]: torch.gt(t, 0).int()
Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)

此更改的原因已在此功能请求问题中进行了讨论: issues/4764 - Introduce torch.BoolTensor …


TL;DR : 简单的一个班轮

t.bool().int()

原文由 kmario23 发布,翻译遵循 CC BY-SA 4.0 许可协议

PyTorch 的 to(dtype) 方法具有方便 的数据类型命名别名。您可以简单地调用 bool

 >>> t.bool()
tensor([False,  True, False,  True])

 >>> t.bool().int()
tensor([0, 1, 0, 1], dtype=torch.int32)

原文由 iacob 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题