Exporting ResNet-50 to TensorRT fails with "Unknown type bool encountered in graph lowering". How do I fix it?


import torch_tensorrt
import torch.nn as nn
import torchvision.models as models
import torch
from loguru import logger

device = 'cpu'

torch_resnet50 = models.resnet50(pretrained=True)
torch_resnet50.to(device)



batch_size = 100
image_channel = 3
image_size = 224

input_data = torch.randn(batch_size, image_channel, image_size, image_size)
input_data=input_data.to(device)

# export_pytorch_script
# traced_model = torch.jit.trace(torch_resnet50, input_data)

# traced_model.save("model/torch_resnet50_torch_script.pt")
# logger.debug(f'export_pytorch_script done')

# export_pytorch_tensorRT


inputs = [
    torch_tensorrt.Input(
        min_shape=[batch_size, image_channel, image_size, image_size],
        opt_shape=[batch_size, image_channel, image_size, image_size],
        max_shape=[batch_size, image_channel, image_size, image_size],
    )
]
enabled_precisions = {torch.float}  # FP32; use {torch.half} to run with FP16

trt_ts_module = torch_tensorrt.compile(
    torch_resnet50, inputs=inputs, enabled_precisions=enabled_precisions
)

# input_data = input_data.to(device)

torch.jit.save(trt_ts_module, "model/torch_resnet50_tensorRT.ts")
logger.debug(f'export_pytorch_tensorRT done')

The code is based on: https://pytorch.org/TensorRT/getting_started/getting_started_...

Running it fails with:

Traceback (most recent call last):
  File "/home/ponponon/code/torch_example/torch_resnet50.py", line 39, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 125, in compile
    return torch_tensorrt.ts.compile(
  File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: Unknown type bool encountered in graph lowering. This type is not supported in ONNX export.

My question is: where does a bool type come from in my code?
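One quick check (a sketch, not part of the original thread): every nn.Module carries a plain Python bool attribute, training, and a freshly built model has it set to True on every submodule. The snippet below lists those flags for a small Conv/BatchNorm/ReLU stack used as a hypothetical stand-in for ResNet-50's layers; training_flags is an illustrative helper, not a library function.

```python
import torch.nn as nn


def training_flags(model: nn.Module) -> dict:
    """Hypothetical helper: map each submodule name to its bool `training` flag."""
    return {name or "<root>": m.training for name, m in model.named_modules()}


# Hypothetical stand-in for a ResNet block: Conv -> BatchNorm -> ReLU.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Newly constructed modules start in training mode: every flag is True.
print("before eval():", training_flags(model))

# eval() recursively sets training = False on the module and all its children.
model.eval()
print("after eval():", training_flags(model))
```

If any flag is still True right before compiling, the model was never switched to inference mode.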

1 answer

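Solved: ResNet-50 has to be switched to inference mode with eval() before compiling. The bool most likely comes from training-mode state: while the BatchNorm layers are still updating their running statistics, the module cannot be frozen into constants, so attributes such as each module's bool training flag are left for the graph-lowering pass, which rejects them.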

torch_resnet50 = models.resnet50(pretrained=True)
torch_resnet50.eval()  # switch to inference mode before torch_tensorrt.compile
torch_resnet50.to(device)
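The eval-mode requirement can be reproduced without TensorRT. The sketch below assumes that Torch-TensorRT's TorchScript path freezes the module before lowering, which is what torch.jit.freeze does in plain PyTorch; it uses a small hypothetical Conv/BatchNorm model instead of ResNet-50. Freezing is refused while the module is in training mode and succeeds after eval().

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for ResNet-50: a Conv -> BatchNorm -> ReLU block.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Script the model while it is still in training mode (the default).
scripted = torch.jit.script(model)

# torch.jit.freeze refuses modules that are still in training mode.
try:
    torch.jit.freeze(scripted)
    freeze_in_train_mode_ok = True
except RuntimeError as err:
    freeze_in_train_mode_ok = False
    print(f"train mode: {err}")

# After eval(), BatchNorm uses its stored running statistics, so the module
# can be frozen: parameters and attributes are inlined as graph constants.
model.eval()
frozen = torch.jit.freeze(torch.jit.script(model))

# The frozen module should match the eager model's output.
x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    max_diff = (frozen(x) - model(x)).abs().max().item()
print(f"eval mode: frozen OK, max abs diff vs eager = {max_diff:.2e}")
```

In the original script this corresponds to calling torch_resnet50.eval() before torch_tensorrt.compile(...).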