如何使用pytorch同时迭代两个数据加载器?

新手上路,请多包涵

我正在尝试实现一个包含两个图像的连体网络。我加载这些图像并创建两个单独的数据加载器。

在我的循环中,我想同时通过两个数据加载器,以便我可以在两个图像上训练网络。

 for i, data in enumerate(zip(dataloaders1, dataloaders2)):

    # get the inputs
    inputs1 = data[0][0].cuda(async=True);
    labels1 = data[0][1].cuda(async=True);

    inputs2 = data[1][0].cuda(async=True);
    labels2 = data[1][1].cuda(async=True);

    labels1 = labels1.view(batchSize,1)
    labels2 = labels2.view(batchSize,1)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs1 = alexnet(inputs1)
    outputs2 = alexnet(inputs2)

数据加载器的返回值是一个元组。但是,当我尝试使用 zip 迭代它们时,出现以下错误:

 OSError: [Errno 24] Too many open files
Exception NameError: "global name 'FileNotFoundError' is not defined" in <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7f2d3c00c190>> ignored

zip 不应该适用于所有可迭代的项目吗?但似乎在这里我不能在数据加载器上使用它。

还有其他方法可以追求这个吗?还是我错误地接近了 Siamese 网络的实施?

原文由 aa1 发布,翻译遵循 CC BY-SA 4.0 许可协议

阅读 788
2 个回答

我看到您正在努力实现正确的数据加载器功能。我会做:

 class Siamese(Dataset):

    def __init__(self, transform=None):

       #init data here

    def __len__(self):
        return   #length of the data

    def __getitem__(self, idx):
        #get images and labels here
        #returned images must be tensor
        #labels should be int
        return img1, img2 , label1, label2

原文由 macharya 发布,翻译遵循 CC BY-SA 4.0 许可协议

除了已经提到的内容之外, cycle()zip() 可能会造成内存泄漏问题- 特别是在使用图像数据集时!要解决这个问题,而不是像这样迭代:

 dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):

    for i, (data1, data2) in enumerate(zip(cycle(dataloaders1), dataloaders2)):

        do_cool_things()

你可以使用:

 dataloaders1 = DataLoader(DummyDataset(0, 100), batch_size=10, shuffle=True)
dataloaders2 = DataLoader(DummyDataset(0, 200), batch_size=10, shuffle=True)
num_epochs = 10

for epoch in range(num_epochs):
    dataloader_iterator = iter(dataloaders1)

    for i, data1 in enumerate(dataloaders2)):

        try:
            data2 = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloaders1)
            data2 = next(dataloader_iterator)

        do_cool_things()

Bear in mind that if you use labels as well, you should replace in this example data1 with (inputs1,targets1) and data2 with inputs2,targets2 , as @Sajad Norouzi 说。

感谢这个: https ://github.com/pytorch/pytorch/issues/1917#issuecomment-433698337

原文由 afroditi 发布,翻译遵循 CC BY-SA 4.0 许可协议

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题