如何在 PyTorch 中保存经过训练的模型?我读过:
torch.save()
/torch.load()
用于保存/加载可序列化对象。model.state_dict()
/model.load_state_dict()
用于保存/加载模型状态。
原文由 Wasi Ahmad 发布,翻译遵循 CC BY-SA 4.0 许可协议
如何在 PyTorch 中保存经过训练的模型?我读过:
torch.save()
/ torch.load()
用于保存/加载可序列化对象。model.state_dict()
/ model.load_state_dict()
用于保存/加载模型状态。原文由 Wasi Ahmad 发布,翻译遵循 CC BY-SA 4.0 许可协议
这取决于你想做什么。
案例 #1:保存模型以供自己用于推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为您通常有 BatchNorm
和 Dropout
层,它们在构造时默认处于训练模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例#2:保存模型以便稍后恢复训练:如果您需要继续训练您将要保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、时期、分数等。您可以这样做:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
要恢复训练,您可以执行以下操作: state = torch.load(filepath)
,然后恢复每个单独对象的状态,如下所示:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
由于您正在恢复训练,一旦您在加载时恢复状态, 请勿 调用 model.eval()
。
案例 # 3:其他人无法访问您的代码而使用的模型:在 Tensorflow 中,您可以创建一个 .pb
文件来定义模型的架构和权重。这非常方便,特别是在使用 Tensorflow serve
时。在 Pytorch 中执行此操作的等效方法是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
这种方式仍然不是防弹的,并且由于 pytorch 仍在进行大量更改,因此我不推荐它。
原文由 Jadiel de Armas 发布,翻译遵循 CC BY-SA 4.0 许可协议
4 回答4.4k 阅读✓ 已解决
4 回答3.8k 阅读✓ 已解决
1 回答3k 阅读✓ 已解决
3 回答2.1k 阅读✓ 已解决
1 回答4.5k 阅读✓ 已解决
1 回答3.8k 阅读✓ 已解决
1 回答2.8k 阅读✓ 已解决
在他们的 github repo 上找到 了这个页面:
另请参阅:官方 PyTorch 教程中的 保存和加载模型 部分。