如何在 Pytorch 的 \`nn.Sequential\` 中压平输入

新手上路,请多包涵

nn.Sequential 中的输入

Model = nn.Sequential(x.view(x.shape[0],-1),
                     nn.Linear(784,256),
                     nn.ReLU(),
                     nn.Linear(256,128),
                     nn.ReLU(),
                     nn.Linear(128,64),
                     nn.ReLU(),
                     nn.Linear(64,10),
                     nn.LogSoftmax(dim=1))

原文由 Khagendra 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 472
2 个回答

您可以如下创建一个新模块/类,并在使用其他模块时按顺序使用它(调用 Flatten() )。

 class Flatten(torch.nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

参考: https ://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983

编辑: Flatten 现在是火炬的一部分。请参阅 https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten

原文由 Umang Gupta 发布,翻译遵循 CC BY-SA 4.0 许可协议

正如所定义的 flatten 方法

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

速度与 view() 相当,但 reshape 甚至更快。

 import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

flatten = Flatten()

t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)

#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)

#https://pytorch.org/docs/master/torch.html#torch.view
f = t.view(t.size(0), -1)
print(f, f.shape)

#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)


测速

# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

如果我们使用上面的类

flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)

5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

这个结果表明创建一个类的方法会更慢。这就是为什么在前向内压平张量更快的原因。我认为这是他们没有提升的主要原因 nn.Flatten

所以我的建议是使用内锋来提高速度。是这样的:

 out = inp.reshape(inp.size(0), -1)

原文由 prosti 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题