在 Tensorflow 中,最后一个维度如何使用 tf.gather()?

新手上路,请多包涵

我正在尝试根据层之间的部分连接的最后一个维度来收集张量的切片。因为输出张量的形状是 [batch_size, h, w, depth] ,所以我想根据最后一个维度选择切片,例如

# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]

但是, tf.gather(L, [0, 2,3,8]) 似乎只适用于第一维(对吗?)谁能告诉我该怎么做?

原文由 YW P Kwon 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 540
2 个回答

这里有一个跟踪错误来支持这个用例: https ://github.com/tensorflow/tensorflow/issues/206

现在你可以:

  1. 转置您的矩阵,以便首先收集维度(转置很昂贵)

  2. 将你的张量重塑为 1d(重塑很便宜)并将你的收集列索引转换为线性索引的单个元素索引列表,然后重塑回来

  3. 使用 gather_nd 。仍然需要将您的列索引转换为单个元素索引的列表。

原文由 Yaroslav Bulatov 发布,翻译遵循 CC BY-SA 3.0 许可协议

推荐问题