如何在 PyTorch 中保存经过训练的模型?

新手上路,请多包涵
阅读 954
2 个回答

在他们的 github repo 上找到 了这个页面

保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)只保存和加载模型参数:

 torch.save(the_model.state_dict(), PATH)

然后后来:

 the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

 torch.save(the_model, PATH)

然后后来:

 the_model = torch.load(PATH)

但是在这种情况下,序列化的数据绑定到特定的类和使用的确切目录结构,因此在其他项目中使用时,或者经过一些严重的重构后,它可能会以各种方式中断。


另请参阅:官方 PyTorch 教程中的 保存和加载模型 部分。

原文由 dontloo 发布,翻译遵循 CC BY-SA 4.0 许可协议

这取决于你想做什么。

案例 #1:保存模型以供自己用于推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为您通常有 BatchNormDropout 层,它们在构造时默认处于训练模式:

 torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便稍后恢复训练:如果您需要继续训练您将要保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、时期、分数等。您可以这样做:

 state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢复训练,您可以执行以下操作: state = torch.load(filepath) ,然后恢复每个单独对象的状态,如下所示:

 model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,一旦您在加载时恢复状态, 请勿 调用 model.eval()

案例 # 3:其他人无法访问您的代码而使用的模型:在 Tensorflow 中,您可以创建一个 .pb 文件来定义模型的架构和权重。这非常方便,特别是在使用 Tensorflow serve 时。在 Pytorch 中执行此操作的等效方法是:

 torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式仍然不是防弹的,并且由于 pytorch 仍在进行大量更改,因此我不推荐它。

原文由 Jadiel de Armas 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题