如 nn.Sequential
中的输入
Model = nn.Sequential(x.view(x.shape[0],-1),
nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim=1))
原文由 Khagendra 发布,翻译遵循 CC BY-SA 4.0 许可协议
您可以如下创建一个新模块/类,并在使用其他模块时按顺序使用它(调用
Flatten()
)。参考: https ://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983
编辑:
Flatten
现在是火炬的一部分。请参阅 https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten