.view() 在 PyTorch 中做什么?

新手上路,请多包涵

--- 对张量 x .view() 做了什么?负值是什么意思?

 x = x.view(-1, 16 * 5 * 5)

原文由 Wasi Ahmad 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 447
2 个回答

view() 在不复制内存的情况下重塑张量,类似于 numpy 的 reshape()

给定一个包含 16 个元素的张量 a

 import torch
a = torch.range(1, 16)

要重塑此张量使其成为 4 x 4 张量,请使用:

 a = a.view(4, 4)

现在 a 将是 4 x 4 张量。 请注意,重塑后元素的总数需要保持不变。将张量 a 重塑为 3 x 5 张量是不合适的。

参数-1是什么意思?

如果有任何情况你不知道你想要多少行但确定列数,那么你可以用 -1 指定它。 ( 请注意,您可以将其扩展到具有更多维度的张量。只有一个轴值可以是 -1 )。这是一种告诉图书馆的方式:“给我一个有这么多列的张量,你计算出实现这一目标所需的适当行数”。

这可以在 这个模型定义代码 中看到。在 forward 函数中的 x = self.pool(F.relu(self.conv2(x))) 行之后,您将拥有一个 16 深度的特征图。您必须将其展平以将其提供给完全连接的层。所以你告诉 PyTorch 重塑你获得的张量以具有特定的列数并告诉它自己决定行数。

原文由 Kashyap 发布,翻译遵循 CC BY-SA 4.0 许可协议

让我们做一些例子,从简单到困难。

  1. view 方法返回一个张量,其数据与 self 张量相同(这意味着返回的张量具有相同数量的元素),但具有不同的形状。例如:
    a = torch.arange(1, 17)  # a's shape is (16,)

   a.view(4, 4) # output below
     1   2   3   4
     5   6   7   8
     9  10  11  12
    13  14  15  16
   [torch.FloatTensor of size 4x4]

   a.view(2, 2, 4) # output below
   (0 ,.,.) =
   1   2   3   4
   5   6   7   8

   (1 ,.,.) =
    9  10  11  12
   13  14  15  16
   [torch.FloatTensor of size 2x2x4]

  1. 假设 -1 不是参数之一,当你将它们相乘时,结果必须等于张量中的元素数。如果您这样做: a.view(3, 3) ,它将引发 RuntimeError 因为形状 (3 x 3) 对于具有 16 个元素的输入无效。换句话说:3 x 3 不等于 16,而是 9。

  2. 您可以使用 -1 作为传递给函数的参数之一,但只能使用一次。所发生的只是该方法将为您计算如何填充该维度。例如 a.view(2, -1, 4) 等同于 a.view(2, 2, 4) 。 [16 / (2 x 4) = 2]

  3. 请注意,返回的张量 共享相同的数据。如果您在“视图”中进行更改,则您正在更改原始张量的数据:

    b = a.view(4, 4)
   b[0, 2] = 2
   a[2] == 3.0
   False

  1. 现在,对于更复杂的用例。文档说每个新视图维度必须是原始维度的子空间,或者仅跨度 d, d + 1, …, d + k 满足以下类似连续性的条件,对于所有 i = 0, . .., k - 1, stride[i] = stride[i + 1] x size[i + 1] 。否则,需要先调用 contiguous() 才能查看张量。例如:
    a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
   a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)

   # The commented line below will raise a RuntimeError, because one dimension
   # spans across two contiguous subspaces
   # a_t.view(-1, 4)

   # instead do:
   a_t.contiguous().view(-1, 4)

   # To see why the first one does not work and the second does,
   # compare a.stride() and a_t.stride()
   a.stride() # (24, 6, 2, 1)
   a_t.stride() # (24, 2, 1, 6)

请注意,对于 a_tstride[0] != stride[1] x size[1] since 24 != 2 x 3

原文由 Jadiel de Armas 发布,翻译遵循 CC BY-SA 3.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题