How can I replace PyTorch's transforms.Compose?

During training, images are preprocessed with torchvision's transforms.Compose, for example like this:

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
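transforms.Compose itself does little more than call each transform in order, feeding the output of one step into the next. A dependency-free stand-in can be sketched as follows; the class name `Compose` mirrors torchvision's, but this is a hypothetical minimal re-implementation, not torchvision's own code.

```python
from typing import Any, Callable, Sequence


class Compose:
    """Minimal, torch-free stand-in for torchvision.transforms.Compose.

    Applies each callable in order, passing the output of one step
    as the input of the next.
    """

    def __init__(self, transforms: Sequence[Callable[[Any], Any]]) -> None:
        self.transforms = list(transforms)

    def __call__(self, x: Any) -> Any:
        for t in self.transforms:
            x = t(x)
        return x


# Usage with toy steps, only to show the chaining order: (2 + 1) * 10.
pipeline = Compose([lambda v: v + 1, lambda v: v * 10])
print(pipeline(2))  # → 30
```

The real work is in the individual steps; any function that takes a PIL image (or array) and returns the next representation can be dropped into the list.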

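Once the model is trained and needs to be deployed, the trained PyTorch model is converted to ONNX.

At that point the dependency on PyTorch has to be removed.

So how can this transforms.Compose pipeline be replaced with an equivalent built on NumPy, PIL, and similar libraries?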
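A sketch of a PIL + NumPy pipeline that mirrors the four steps, assuming torchvision's documented defaults: `Resize(224)` scales the shorter side to 224 with bilinear interpolation, and `CenterCrop(224)` takes the central 224×224 patch. The helper names (`resize_shorter_side`, `center_crop`, `to_normalized_chw`) are hypothetical, the rounding of the new size and crop offsets follows my reading of torchvision's implementation, and the output should be compared against the torchvision pipeline on real images before relying on it.

```python
import numpy as np
from PIL import Image

# Constants built once, in float32, shaped (3, 1, 1) to broadcast over CHW.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)


def resize_shorter_side(image: Image.Image, size: int) -> Image.Image:
    """Like transforms.Resize(size) with an int: scale the shorter side to `size`."""
    w, h = image.size
    if min(w, h) == size:
        return image  # already the target size, nothing to do
    if w <= h:
        new_w, new_h = size, int(size * h / w)
    else:
        new_w, new_h = int(size * w / h), size
    return image.resize((new_w, new_h), Image.BILINEAR)


def center_crop(image: Image.Image, size: int) -> Image.Image:
    """Like transforms.CenterCrop(size), for images at least `size` on each side."""
    w, h = image.size
    left = int(round((w - size) / 2.0))
    top = int(round((h - size) / 2.0))
    return image.crop((left, top, left + size, top + size))


def to_normalized_chw(image: Image.Image) -> np.ndarray:
    """Like ToTensor() + Normalize(): HWC uint8 -> CHW float32, scaled and normalized."""
    hwc = np.asarray(image, dtype=np.float32) / 255.0  # values in [0, 1], stays float32
    chw = hwc.transpose(2, 0, 1)                        # HWC -> CHW view
    # Normalize per channel, then force a C-contiguous layout (copies if needed).
    return np.ascontiguousarray((chw - MEAN) / STD)


def preprocess(image: Image.Image) -> np.ndarray:
    if image.mode != "RGB":  # assumption: the model expects 3-channel RGB input
        image = image.convert("RGB")
    image = resize_shorter_side(image, 224)
    image = center_crop(image, 224)
    return to_normalized_chw(image)
```

The result is a float32 array of shape (3, 224, 224); an ONNX model that expects NCHW input typically also needs a batch axis, e.g. `preprocess(image)[np.newaxis]`.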


I wrote the code below, but it turns out to be 50% slower than transforms.Compose.

I don't understand why it is slower.

from PIL import Image
import numpy as np
from numpy import ndarray

def preprocess(image: Image.Image) -> ndarray:
    # Resize straight to 224x224 (the aspect ratio is not preserved)
    resized_image = image.resize((224, 224))
    # PIL image -> HWC uint8 array
    resized_image_ndarray = np.array(resized_image)
    # HWC -> CHW, the layout ToTensor produces
    transposed_image_ndarray = resized_image_ndarray.transpose((2, 0, 1))
    # uint8 -> float32, scaled to [0, 1]
    transposed_image_ndarrayfloat32 = transposed_image_ndarray.astype(
        np.float32)
    transposed_image_ndarrayfloat32 /= 255.0
    # Per-channel mean/std, shaped (3, 1, 1) to broadcast over CHW
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    normalized_image_ndarray = (transposed_image_ndarrayfloat32 - mean) / std
    # Cast the result to float32
    normalized_image_ndarrayfloat32 = normalized_image_ndarray.astype(
        np.float32)
    return normalized_image_ndarrayfloat32
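Two plausible contributors, offered as unmeasured hypotheses rather than a diagnosis. First, `mean` and `std` are built from Python floats, so they are float64 arrays; subtracting them from a float32 array makes NumPy carry out the whole 3×224×224 computation in float64, which then needs the final `astype(np.float32)` copy, and both constants are rebuilt on every call. Second, on Pillow 7.0 and later `Image.resize` defaults to bicubic resampling, while `transforms.Resize` uses bilinear by default, and bicubic is generally the more expensive filter. (Also note that `resize((224, 224))` stretches the image to a square instead of resizing the shorter side and center-cropping, so the output differs from the torchvision pipeline.) The sketch below demonstrates the promotion and shows a variant of the same function that stays in float32 and passes `Image.BILINEAR` explicitly; the name `preprocess_float32` is hypothetical.

```python
import numpy as np
from PIL import Image

# A float32 array minus a float64 array is computed and stored as float64.
x = np.zeros((3, 224, 224), dtype=np.float32)
mean64 = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))  # dtype is float64
print((x - mean64).dtype)  # → float64

# Constants built once, in float32, outside the per-image function.
MEAN32 = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((3, 1, 1))
STD32 = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((3, 1, 1))


def preprocess_float32(image: Image.Image) -> np.ndarray:
    # Same square resize as the original, but with an explicit bilinear filter
    # (the transforms.Resize default) instead of Pillow's default filter.
    resized = image.resize((224, 224), Image.BILINEAR)
    # Convert straight to float32; dividing by a Python float keeps float32.
    hwc = np.asarray(resized, dtype=np.float32) / 255.0
    # HWC -> CHW; all arithmetic stays float32, so no final astype is needed.
    chw = hwc.transpose((2, 0, 1))
    return np.ascontiguousarray((chw - MEAN32) / STD32)
```

Whether this closes the gap depends on the Pillow version and the size of `bh.jpg`, so it is worth timing each change separately.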

Running 3000 iterations with transforms.Compose takes 9.727 s:

from torchvision import transforms
from PIL import Image
from torch import Tensor
from numpy import ndarray
import numpy


preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
image = Image.open('bh.jpg')


for i in range(3000):
    tensor: Tensor = preprocess(image)

Running 3000 iterations with my NumPy version takes 16.093 s:

from PIL import Image
import numpy as np
from numpy import ndarray

def preprocess(image: Image.Image) -> ndarray:
    resized_image = image.resize((224, 224))
    resized_image_ndarray = np.array(resized_image)
    transposed_image_ndarray = resized_image_ndarray.transpose((2, 0, 1))
    transposed_image_ndarrayfloat32 = transposed_image_ndarray.astype(
        np.float32)
    transposed_image_ndarrayfloat32 /= 255.0
    mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    normalized_image_ndarray = (transposed_image_ndarrayfloat32 - mean) / std
    normalized_image_ndarrayfloat32 = normalized_image_ndarray.astype(
        np.float32)
    return normalized_image_ndarrayfloat32


image = Image.open('bh.jpg')

for i in range(3000):
    preprocessed_ndarray: ndarray = preprocess(image)
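Neither script shows how the elapsed time was measured. For a fair comparison both functions can be timed the same way, for example with `time.perf_counter` and a few warm-up calls. The helper name `bench` is hypothetical, and the synthetic image below stands in for `bh.jpg`; with a real file, calling `image.load()` before timing keeps Pillow's lazy JPEG decoding out of the measurement.

```python
import time
from typing import Any, Callable

import numpy as np
from PIL import Image


def bench(fn: Callable[[Any], Any], arg: Any, rounds: int = 3000, warmup: int = 10) -> float:
    """Return the total wall-clock seconds for `rounds` calls of fn(arg)."""
    for _ in range(warmup):  # keep one-off costs out of the measurement
        fn(arg)
    start = time.perf_counter()
    for _ in range(rounds):
        fn(arg)
    return time.perf_counter() - start


def square_resize(im: Image.Image) -> np.ndarray:
    """Tiny example workload: the square resize from the question."""
    return np.asarray(im.resize((224, 224)), dtype=np.float32)


# Synthetic stand-in for Image.open('bh.jpg').
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
elapsed = bench(square_resize, image, rounds=100)
print(f"100 rounds: {elapsed:.3f} s")
```

To compare the two pipelines, pass each `preprocess` (the transforms.Compose one and the NumPy one) to `bench` with the same loaded image.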