When using PyTorch, how can I reuse a DataLoader instead of re-instantiating it over and over?

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
from typing import List, Tuple, Union
from numpy import ndarray
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


class PreprocessImageDataset(Dataset):
    def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]

        image = Image.fromarray(image)

        # preprocess yields a (C, H, W) float tensor; the DataLoader adds the batch dim
        preprocessed_image: torch.Tensor = preprocess(image)

        return preprocessed_image


if __name__ == '__main__':
    # Create some sample data: small random uint8 images,
    # so that Image.fromarray in __getitem__ actually works
    data = [np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
            for _ in range(1000)]


    batch_size = 10
    num_workers = 16

    dataset = PreprocessImageDataset(data)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers)

    # Iterate over the batches in the training loop
    for batch_data in dataloader:
        print("Batch data:", batch_data)
        print("Batch data type:", type(batch_data))
        print("Batch data shape:", batch_data.shape)

Every time a new chunk of data arrives, I have to call DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) again, which creates and destroys the worker process pool over and over.
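To make the pattern concrete, this is roughly what the current approach looks like (a minimal sketch; process_stream, chunks, and run_inference are hypothetical names, not part of the code above):

def process_stream(chunks):
    for chunk in chunks:                              # e.g. frames decoded from a video
        dataset = PreprocessImageDataset(chunk)
        dataloader = DataLoader(dataset, batch_size=10,
                                num_workers=16)       # 16 worker processes spawned here
        for batch in dataloader:
            run_inference(batch)                      # placeholder for the real work
        # ... and the 16 workers are torn down again here, once per chunk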

How can the DataLoader be reused?

1 Answer
Create the DataLoader once, outside the loop that consumes the data, and keep iterating the same object; a DataLoader is re-iterable, so it does not need to be rebuilt per epoch. The imports, preprocess pipeline, and PreprocessImageDataset class are identical to the question's, so only the __main__ block changes:


if __name__ == '__main__':

    data = [np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
            for _ in range(1000)]

    batch_size = 10
    num_workers = 16

    dataset = PreprocessImageDataset(data)
    # persistent_workers=True (available since PyTorch 1.7) keeps the worker
    # processes alive between epochs; by default the pool is shut down and
    # respawned every time an iteration over the DataLoader finishes.
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers,
                            persistent_workers=True)

    for epoch in range(5):
        print(f"Epoch {epoch + 1}:")
        for batch_data in dataloader:
            print("Batch data:", batch_data)
            print("Batch data type:", type(batch_data))
            print("Batch data shape:", batch_data.shape)