我正在尝试根据层之间的部分连接的最后一个维度来收集张量的切片。因为输出张量的形状是 [batch_size, h, w, depth]
,所以我想根据最后一个维度选择切片,例如
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
但是, tf.gather(L, [0, 2,3,8])
似乎只适用于第一维(对吗?)谁能告诉我该怎么做?
原文由 YW P Kwon 发布,翻译遵循 CC BY-SA 4.0 许可协议
我正在尝试根据层之间的部分连接的最后一个维度来收集张量的切片。因为输出张量的形状是 [batch_size, h, w, depth]
,所以我想根据最后一个维度选择切片,例如
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
但是, tf.gather(L, [0, 2,3,8])
似乎只适用于第一维(对吗?)谁能告诉我该怎么做?
原文由 YW P Kwon 发布,翻译遵循 CC BY-SA 4.0 许可协议
从 TensorFlow 1.3 tf.gather
开始,有一个 axis
参数,因此不再需要此处的各种解决方法。
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
原文由 rryan 发布,翻译遵循 CC BY-SA 3.0 许可协议
2 回答5.2k 阅读✓ 已解决
2 回答1.1k 阅读✓ 已解决
4 回答1.4k 阅读✓ 已解决
3 回答1.3k 阅读✓ 已解决
3 回答1.3k 阅读✓ 已解决
2 回答895 阅读✓ 已解决
1 回答1.8k 阅读✓ 已解决
这里有一个跟踪错误来支持这个用例: https ://github.com/tensorflow/tensorflow/issues/206
现在你可以:
转置您的矩阵,以便首先收集维度(转置很昂贵)
将你的张量重塑为 1d(重塑很便宜)并将你的收集列索引转换为线性索引的单个元素索引列表,然后重塑回来
使用
gather_nd
。仍然需要将您的列索引转换为单个元素索引的列表。