pytorch 数据加载器多次迭代

新手上路,请多包涵

我使用 iris-dataset 通过 pytorch 训练一个简单的网络。

 trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
                                          shuffle=True, num_workers=2)

dataiter = iter(trainloader)

数据集本身只有 150 个数据点,由于批处理大小为 150,pytorch 数据加载器只对整个数据集迭代一次。

我现在的问题是,如果迭代完成后,通常有什么方法可以告诉 pytorch 的数据加载器重复数据集吗?

thnaks

更新

让它运行:)刚刚创建了一个数据加载器的子类并实现了我自己的 __next__()

原文由 arash javanmard 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 410
1 个回答

补充前面的答案。为了在数据集之间进行比较,通常最好使用总步数而不是总纪元数作为超参数。这是因为迭代次数不应该依赖于数据集的大小,而是依赖于它的复杂性。

我正在使用以下代码进行培训。它确保数据加载器每次重新启动时都会重新洗牌数据。

 # main training loop
    generator = iter(trainloader)
    for i in range(max_steps):

        try:
            # Samples the batch
            x, y = next(generator)
        except StopIteration:
            # restart the generator if the previous generator is exhausted.
            generator = iter(trainloader)
            x, y = next(generator)

我同意这不是最优雅的解决方案,但它使我不必依赖 epochs 进行训练。

原文由 Victor Zuanazzi 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题