https://pytorch.org/docs/stable/onnx.html?highlight=torch%20o...
这里有一份 torch.onnx.export 的 demo 代码,运行后保存一个 onnx.pd 文件
import torch
class SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"]
)
我该如何使用这个 onnx.pb 文件呢?
我写了下面的 demo 代码:
import onnxruntime
import numpy
import torch
resnet_onnx = onnxruntime.InferenceSession(
'onnx.pb')
x = torch.ones(2, 2)
inputs = {resnet_onnx.get_inputs()[0].name: x}
resnet_onnx.run(None, inputs)
然而报错了:
─➤ python -u "/Users/ponponon/Desktop/code/me/onnx_example/use.py" 1 ↵
Traceback (most recent call last):
File "/Users/ponponon/Desktop/code/me/onnx_example/use.py", line 12, in <module>
resnet_onnx.run(None, inputs)
File "/Users/ponponon/.local/share/virtualenvs/onnx_example-mTitCDnh/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
return self._sess.run(output_names, input_feed, run_options)
RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'x'.
我该如何解决呢?
问题已解决:PyTorch模型的输入为tensor,而ONNX的输入为array
转成 onnx 之后,输入类型不应该是 torch.Tensor,而是 numpy.ndarray
所以,把调用的 demo 修改为如下就行调通:
解决方案出处:
9.1.3 使用ONNX Runtime进行推理
下面的 chatGPT 给的回答:
你可以使用 ONNX Runtime 来加载和执行保存在 onnx.pb 文件中的模型。
以下是一个使用 ONNX Runtime 加载并运行模型的示例代码:
在这个示例代码中,首先通过
ort.InferenceSession
加载 onnx.pb 文件中的模型,并使用np.ones
函数生成了一个输入数据的示例,然后使用sess.run
方法运行模型,并将输入数据传递给模型。最后,通过print
输出模型的预测结果。你可以根据自己的需要修改输入数据的形状和数值,以及输出结果的处理方式。