如何在 pytorch 中展平张量?

新手上路,请多包涵

给定一个多维张量,我如何将其展平以使其具有 单一 维度?

 torch.Size([2, 3, 5])    ⟶ flatten ⟶    torch.Size([30])

原文由 Tom Hale 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 933
2 个回答

长话短说: torch.flatten()

使用 torch.flatten() 在 v0.4.1 中引入并在 v1.0rc1记录

>  >>> t = torch.tensor([[[1, 2],
>                        [3, 4]],
>                       [[5, 6],
>                        [7, 8]]])
> >>> torch.flatten(t)
> tensor([1, 2, 3, 4, 5, 6, 7, 8])
> >>> torch.flatten(t, start_dim=1)
> tensor([[1, 2, 3, 4],
>         [5, 6, 7, 8]])
>
> ```

对于 v0.4.1 及更早版本,请使用 [`t.reshape(-1)`](https://pytorch.org/docs/master/torch.html#torch.reshape) 。

* * *

随着 `t.reshape(-1)` :

如果请求的视图在内存中是连续的,这将等同于 [`t.view(-1)`](https://pytorch.org/docs/master/tensors.html?highlight=view#torch.Tensor.view) 并且不会复制内存。

否则它将等同于 `t.` [`contiguous()`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.contiguous) `.view(-1)` 。

* * *

其他非选项:

- `t.view(-1)` [不会复制内存,但可能无法工作,具体取决于原始大小和步幅](https://pytorch.org/docs/master/tensors.html?highlight=view#torch.Tensor.view)

- `t.resize(-1)` 给出 `RuntimeError` (见下文)

- `t.resize(t.numel())` [关于低级方法的警告](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.resize_)(见下面的讨论)


(注意: `pytorch` `reshape()` [`reshape()` `numpy`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html)

* * *

`t.resize(t.numel())` 需要一些讨论。 [`torch.Tensor.resize_` 文档](https://pytorch.org/docs/stable/tensors.html#torch.Tensor.resize_) 说:

> 存储被重新解释为 C-contiguous,忽略当前步幅(除非目标大小等于当前大小,在这种情况下张量保持不变)

鉴于新的 `(1, numel())` 大小将忽略当前步幅,元素的顺序 _可能会_ 以与 `reshape(-1)` 不同的顺序出现。但是,“大小” _可能_ 意味着内存大小,而不是张量的大小。

如果 `t.resize(-1)` 既方便又高效,那就太好了,但是 `torch 1.0.1.post2` , `t = torch.rand([2, 3, 5]);  t.resize(-1)` 给出:

RuntimeError: requested resize to -1 (-1 elements in total), but the given tensor has a size of 2x2 (4 elements). autograd’s resize can only change the shape of a given tensor, while preserving the number of elements.

”`

在这里 提出了一个功能请求,但一致认为 resize() 是一种低级方法,应该优先使用 reshape()

原文由 Tom Hale 发布,翻译遵循 CC BY-SA 4.0 许可协议

使用 torch.reshape 并且只能传递一个维度来展平它。如果您不想对维度进行硬编码,只需指定 -1 即可推断出正确的维度。

 >>> x = torch.tensor([[1,2], [3,4]])
>>> x.reshape(-1)
tensor([1, 2, 3, 4])

编辑:

对于你的例子: 在此处输入图像描述

原文由 scarecrow 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题
logo
Stack Overflow 翻译
子站问答
访问
宣传栏