如何调用 torch.onnx.export 导出的模型?

https://pytorch.org/docs/stable/onnx.html?highlight=torch%20o...

图片.png

这里有一份 torch.onnx.export 的 demo 代码,运行后保存一个 onnx.pd 文件

import torch


class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)


torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"]
)

我该如何使用这个 onnx.pb 文件呢?

我写了下面的 demo 代码:

import onnxruntime
import numpy
import torch


resnet_onnx = onnxruntime.InferenceSession(
    'onnx.pb')

x = torch.ones(2, 2)

inputs = {resnet_onnx.get_inputs()[0].name: x}
resnet_onnx.run(None, inputs)

然而报错了:

─➤  python -u "/Users/ponponon/Desktop/code/me/onnx_example/use.py"                                 1 ↵
Traceback (most recent call last):
  File "/Users/ponponon/Desktop/code/me/onnx_example/use.py", line 12, in <module>
    resnet_onnx.run(None, inputs)
  File "/Users/ponponon/.local/share/virtualenvs/onnx_example-mTitCDnh/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'x'.

我该如何解决呢?

阅读 2.6k
1 个回答

问题已解决:PyTorch模型的输入为tensor,而ONNX的输入为array

转成 onnx 之后,输入类型不应该是 torch.Tensor,而是 numpy.ndarray

所以,把调用的 demo 修改为如下就行调通:

import onnxruntime
import numpy
import torch


resnet_onnx = onnxruntime.InferenceSession(
    'onnx.pb')

# x = torch.ones(2, 2) 输入的类型不应该是 torch.Tensor
x = numpy.ones((2, 2), dtype=numpy.float32)  # 而是 numpy.ndarray


inputs = {resnet_onnx.get_inputs()[0].name: x}
print(resnet_onnx.run(None, inputs))

解决方案出处:

图片.png

9.1.3 使用ONNX Runtime进行推理


下面的 chatGPT 给的回答:

你可以使用 ONNX Runtime 来加载和执行保存在 onnx.pb 文件中的模型。

以下是一个使用 ONNX Runtime 加载并运行模型的示例代码:

import onnxruntime as ort
import numpy as np

# 加载模型
sess = ort.InferenceSession("onnx.pb")

# 准备输入数据
input_data = np.ones((2, 2)).astype("float32")

# 运行模型
output_data = sess.run(None, {"x": input_data})[0]

# 输出模型的预测结果
print(output_data)

在这个示例代码中,首先通过 ort.InferenceSession 加载 onnx.pb 文件中的模型,并使用 np.ones 函数生成了一个输入数据的示例,然后使用 sess.run 方法运行模型,并将输入数据传递给模型。最后,通过 print 输出模型的预测结果。

你可以根据自己的需要修改输入数据的形状和数值,以及输出结果的处理方式。

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题