如何将下面的 resnet50 模型导出为 onnx 格式呢?

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torchvision.models as models

def gem(x: Tensor, p: int = 3, eps: float = 1e-6) -> Tensor:
    input = x.clamp(min=eps)
    _input = input.pow(p)
    kernel_size = (x.size(-2), x.size(-1))
    t = F.avg_pool2d(_input, kernel_size).pow(1./p)
    return t

def l2n(x: Tensor, eps: float = 1e-6) -> Tensor:

    return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)

class L2N(nn.Module):

    def __init__(self, eps=1e-6):
        super(L2N, self).__init__()
        self.eps = eps

    def forward(self, x):
        return l2n(x, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'

class GeM(nn.Module):

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gem(x, p=self.p, eps=self.eps)

class ImageRetrievalNet(nn.Module):

    def __init__(self, dim: int = 512):
        super(ImageRetrievalNet, self).__init__()
        resnet50_model = models.resnet50()
        features = list(resnet50_model.children())[:-2]

        self.features = nn.Sequential(*features)

        self.lwhiten = None
        self.pool = GeM()
        self.whiten = nn.Linear(2048, dim, bias=True)
        self.norm = L2N()

    def forward(self, x: Tensor):
        o: Tensor = self.features(x)

        # features -> pool -> norm
        p = 3
        eps = 1e-6
        pooled_t = gem(o, p, eps)

        normed_t: Tensor = self.norm(pooled_t)
        o: Tensor = normed_t.squeeze(-1).squeeze(-1)

        # 启用白化,则: pooled features -> whiten -> norm
        if self.whiten is not None:
            whitened_t = self.whiten(o)
            normed_t: Tensor = self.norm(whitened_t)
            o = normed_t

        # 使每个图像为Dx1列向量(如果有许多图像,则为DxN)
        return o.permute(1, 0)

# 创建 PyTorch ResNet50 模型实例
model = ImageRetrievalNet()

# # 定义一个 PyTorch 张量来模拟输入数据
# batch_size = 4  # 定义批处理大小
# input_shape = (batch_size, 3, 224, 224)
# input_data = torch.randn(input_shape)

# # 将模型转换为 ONNX 格式
# output_path = "resnet50.onnx"
# torch.onnx.export(
#     model,
#     input_data,
#     output_path,
#     input_names=["input"], output_names=["output"],
#     dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
# )

batch_size = 4  # 定义批处理大小
input_shape = (batch_size, 3, 224, 224)
input_data = torch.randn(input_shape)

# 指定所有张量的静态形状
input_shape = ["batch_size", "channels", "height", "width"]
output_shape = ["batch_size", "features"]

    input_names=["input"], output_names=["output"],
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}



/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
Traceback (most recent call last):
  File "/home/ponponon/code/torch_example/resnet50_export_onnx copy.py", line 123, in <module>
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 504, in export
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 1529, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 1115, in _model_to_graph
    graph = _optimize_graph(
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 663, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/utils.py", line 1899, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 380, in wrapper
    return fn(g, *args, **kwargs)
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 286, in wrapper
    args = [
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 287, in <listcomp>
    _parse_arg(arg, arg_desc, arg_name, fn_name)  # type: ignore[assignment]
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 104, in _parse_arg
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Failed to export a node '%518 : Long(device=cpu) = onnx::Squeeze[axes=[0]](%517), scope: __main__.ImageRetrievalNet:: # /home/ponponon/code/torch_example/resnet50_export_onnx copy.py:12:0
' (in list node %527 : int[] = prim::ListConstruct(%518, %526), scope: __main__.ImageRetrievalNet::
) because it is not constant. Please try to make things (e.g. kernel sizes) static if possible.  [Caused by the value '527 defined in (%527 : int[] = prim::ListConstruct(%518, %526), scope: __main__.ImageRetrievalNet::
)' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.] 

        #0: 518 defined in (%518 : Long(device=cpu) = onnx::Squeeze[axes=[0]](%517), scope: __main__.ImageRetrievalNet:: # /home/ponponon/code/torch_example/resnet50_export_onnx copy.py:12:0
    )  (type 'Tensor')
        #1: 526 defined in (%526 : Long(device=cpu) = onnx::Squeeze[axes=[0]](%525), scope: __main__.ImageRetrievalNet:: # /home/ponponon/code/torch_example/resnet50_export_onnx copy.py:12:0
    )  (type 'Tensor')
        #0: 527 defined in (%527 : int[] = prim::ListConstruct(%518, %526), scope: __main__.ImageRetrievalNet::
    )  (type 'List[int]')
阅读 3.6k
1 个回答

问题是由于 ONNX 导出过程中无法处理某些动态值。请试以下更改以解决此问题:从 ImageRetrievalNet 类中删除 self.lwhiten 属性,因为它没有被使用。
更改 GeM 类的 forward 方法,使其不再使用 Parameter 类型的属性 self.p,而是直接使用一个常量值。
以下是修改后的 ImageRetrievalNet 和 GeM 类:

class GeM(nn.Module):

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = p
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gem(x, p=self.p, eps=self.eps)

class ImageRetrievalNet(nn.Module):

    def __init__(self, dim: int = 512):
        super(ImageRetrievalNet, self).__init__()
        resnet50_model = models.resnet50()
        features = list(resnet50_model.children())[:-2]

        self.features = nn.Sequential(*features)
        self.pool = GeM()
        self.whiten = nn.Linear(2048, dim, bias=True)
        self.norm = L2N()

    def forward(self, x: Tensor):
        o: Tensor = self.features(x)

        # features -> pool -> norm
        pooled_t = self.pool(o)
        normed_t: Tensor = self.norm(pooled_t)
        o: Tensor = normed_t.squeeze(-1).squeeze(-1)

        # 启用白化,则: pooled features -> whiten -> norm
        if self.whiten is not None:
            whitened_t = self.whiten(o)
            normed_t: Tensor = self.norm(whitened_t)
            o = normed_t

        # 使每个图像为Dx1列向量(如果有许多图像,则为DxN)
        return o.permute(1, 0)

现在,您应该可以正常导出 ONNX 模型,如下所示:

# 创建 PyTorch ResNet50 模型实例
model = ImageRetrievalNet()

# 定义一个 PyTorch 张量来模拟输入数据
batch_size = 4  # 定义批处理大小
input_shape = (batch_size, 3, 224, 224)
input_data = torch.randn(input_shape)

# 将模型转换为 ONNX 格式
output_path = "resnet50.onnx"
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}

这些更改应解决在将模型导出为 ONNX 时遇到的错误。

