Q:torchvision 的 transforms.ToTensor 干了什么事情?

A:torchvision 的 transforms.ToTensor() 是将 PIL Image 或 numpy.ndarray 转化成 torch.FloatTensor 数据类型的方法。这个方法的主要功能是:

  • 将 PIL Image 或 numpy.ndarray 转化成 torch.FloatTensor 数据类型。
  • 像素值在 0-255 范围内的 PIL Image 或 numpy.ndarray 在被转化成 torch.FloatTensor 时,会自动除以 255,使其像素值在 0-1 之间。
  • 如果像素值已经在 0-1 范围内,那么就不会自动除以 255。

简而言之,transforms.ToTensor() 的作用是将输入的 PIL Image 或 numpy.ndarray 转换为 Tensor 类型,并且进行标准化操作,方便神经网络的训练。因为神经网络需要的输入数据类型一般是 FloatTensor 类型,且需要进行标准化,这个过程常常使用 transforms.ToTensor() 方法来实现。


Q:torchvision 的 transforms.ToTensor() 处理过的数据,还能复原回一张图片吗?
A:可以通过 transforms.ToPILImage() 将处理后的数据转换为 PIL 图像格式,从而复原回一张图片。例如:

from torchvision import transforms
from PIL import Image

# 读取图片
image = Image.open('example.jpg')

# 对图片进行 transforms.ToTensor() 处理
preprocess = transforms.Compose([
    transforms.ToTensor()
])
tensor = preprocess(image)

# 将处理后的数据还原成一张图片
to_pil = transforms.ToPILImage()
image_restored = to_pil(tensor)
image_restored.show()

这段代码会将一张名为 example.jpg 的图片读入,经过 transforms.ToTensor() 处理后,将处理后的数据还原成一张图片并展示出来。


universe_king
3.4k 声望680 粉丝