import torch
from torch.utils.data import DataLoader, Dataset
from math import sqrt
from typing import List, Tuple, Union
from numpy import ndarray
from PIL import Image
from torchvision import transforms
# Standard ImageNet-style preprocessing: resize the short side to 256,
# take a 224x224 center crop, convert to a CHW float tensor in [0, 1],
# then normalize with the usual ImageNet channel statistics.
_PIPELINE_STEPS = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
preprocess = transforms.Compose(_PIPELINE_STEPS)
class PreprocessImageDataset(Dataset):
    """Map-style dataset turning raw image arrays into normalized tensors.

    Each item is passed through the module-level ``preprocess`` pipeline and
    returned as a single (C, H, W) tensor. Batching — and therefore the
    leading batch dimension — is handled by the DataLoader, so no
    ``unsqueeze`` is performed here (the original kept a dead
    ``unsqueezed_image`` alias that never unsqueezed anything).
    """

    def __init__(self, images: Union[List[ndarray], Tuple[ndarray, ...]]):
        # Arrays must be in a layout PIL.Image.fromarray accepts
        # (e.g. uint8 HxWx3) — TODO confirm with callers.
        self.images = images

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Round-trip through PIL because the Resize/CenterCrop transforms
        # in ``preprocess`` operate on PIL images.
        pil_image = Image.fromarray(self.images[idx])
        return preprocess(pil_image)
if __name__ == '__main__':
    import numpy as np

    # Demo / smoke test. The original used ``list(range(10000000))`` as
    # sample data, but ``Image.fromarray`` requires ndarrays, so every
    # ``__getitem__`` call would have crashed. Use a handful of small
    # random uint8 images instead.
    num_images = 32
    data = [
        np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
        for _ in range(num_images)
    ]
    batch_size = 10
    num_workers = 16  # worker processes for parallel preprocessing

    dataset = PreprocessImageDataset(data)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers)

    # Iterate the batches as one would inside a training loop.
    for batch_data in dataloader:
        print("Batch data:", batch_data)
        print("Batch data type :", type(batch_data))
        print("Batch data shape:", batch_data.shape)
# NOTE(review): as written, every new batch of incoming data requires a fresh
# DataLoader(dataset, batch_size=batch_size, num_workers=num_workers), which
# repeatedly creates and destroys the worker process pool. To reuse the
# DataLoader instead, keep one long-lived DataLoader and pass
# persistent_workers=True so the workers survive across iterations/epochs.