resnet50 导出成 tensorRT 报错 Unknown type bool encountered in graph lowering 怎么办?
import torch_tensorrt
import torch.nn as nn
import torchvision.models as models
import torch
from loguru import logger
# Export a pretrained torchvision ResNet-50 to a Torch-TensorRT TorchScript
# module and save it to disk.

# NOTE(review): TensorRT normally targets CUDA GPUs and the official tutorial
# moves the model to "cuda" -- confirm that 'cpu' is really intended here.
device = 'cpu'

torch_resnet50 = models.resnet50(pretrained=True)
# Switch to inference mode BEFORE compiling. Leaving the model in training
# mode makes torch_tensorrt.compile fail with
# "RuntimeError: Unknown type bool encountered in graph lowering"
# (presumably because of the training-only branches in layers such as
# BatchNorm -- the exact trigger is not visible from this script).
torch_resnet50.eval()
torch_resnet50.to(device)

batch_size = 100
image_channel = 3
image_size = 224
input_shape = [batch_size, image_channel, image_size, image_size]

# Sample input; only needed by the optional TorchScript export below.
input_data = torch.randn(*input_shape)
input_data = input_data.to(device)

# export_pytorch_script (optional, kept for reference)
# traced_model = torch.jit.trace(torch_resnet50, input_data)
# traced_model.save("model/torch_resnet50_torch_script.pt")
# logger.debug(f'export_pytorch_script done')

# export_pytorch_tensorRT
# min/opt/max shapes are identical, so the engine is built for exactly one
# input shape. Each bound gets its own list copy, as in the original script.
inputs = [
    torch_tensorrt.Input(
        min_shape=list(input_shape),
        opt_shape=list(input_shape),
        max_shape=list(input_shape),
    )
]
# torch.float is fp32; use {torch.half} instead to allow fp16 kernels.
enabled_precisions = {torch.float}
trt_ts_module = torch_tensorrt.compile(
    torch_resnet50, inputs=inputs, enabled_precisions=enabled_precisions
)
# The target directory "model/" must already exist.
torch.jit.save(trt_ts_module, "model/torch_resnet50_tensorRT.ts")
logger.debug('export_pytorch_tensorRT done')
代码参考的是: https://pytorch.org/TensorRT/getting_started/getting_started_...
运行报错:
Traceback (most recent call last):
File "/home/ponponon/code/torch_example/torch_resnet50.py", line 39, in <module>
trt_ts_module = torch_tensorrt.compile(
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 125, in compile
return torch_tensorrt.ts.compile(
File "/home/ponponon/.local/share/virtualenvs/torch_example-qg0YNkbt/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: Unknown type bool encountered in graph lowering. This type is not supported in ONNX export.
问题是,我这里哪来的 bool 类型呢?
解决了:要先把 resnet50 切换到推理模式,即在调用 torch_tensorrt.compile 之前执行 torch_resnet50.eval()。