TensorRT
🐛 [Bug] Unable to export PyTorch transformer encoder using TensorRT
Bug Description
Torch-TensorRT fails to compile a PyTorch transformer encoder (torch.nn.TransformerEncoder). The PyTorch transformer decoder (torch.nn.TransformerDecoder) fails with the same error.
To Reproduce
Steps to reproduce the behavior:
- Pull the Docker image nvcr.io/nvidia/pytorch:22.04-py3
- Inside the container, run the code snippet below
Code
import torch_tensorrt
import torch
from torch.nn.modules import TransformerEncoder
from torch.nn.modules import TransformerEncoderLayer as TorchTransformerEncoderLayer
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)
layer = TorchTransformerEncoderLayer(
    batch_first=False,
    activation="relu",
    d_model=256,
    dim_feedforward=2048,
    dropout=0.1,
    nhead=8,
    device="cuda"
)
encoder = TransformerEncoder(layer, 3)
scripted_module = torch.jit.script(encoder)
trt_ts_module = torch_tensorrt.compile(
    scripted_module,
    inputs=[
        torch_tensorrt.Input(
            shape=[256, 1, 256], dtype=torch.float
        )
    ],
)
torch.jit.save(trt_ts_module, "tensorrt_minimal.ts")
Error message
/opt/conda/lib/python3.8/site-packages/torch/jit/_recursive.py:240: UserWarning: 'batch_first' was found in ScriptModule constants, but was not actually set in __init__. Consider removing it.
warnings.warn("'{}' was found in ScriptModule constants, "
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
torch_executed_modules: [
]
Traceback (most recent call last):
  File "tensorrt_minimal.py", line 22, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 116, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: Unknown type bool encountered in graph lowering. This type is not supported in ONNX export.
Expected behavior
Torch-TensorRT should compile the PyTorch encoder without errors.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.1.0a0
- PyTorch Version (e.g. 1.0): 1.12.0a0+bd13bc6
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 20.04.4
- How you installed PyTorch (conda, pip, libtorch, source): conda
- Build command you used (if compiling from source): /
- Are you using local sources or building from archives: /
- Python version: 3.8.13
- CUDA version: 11.6
- GPU models and configuration: NVIDIA RTX A6000
- Any other relevant information: using unmodified docker image nvcr.io/nvidia/pytorch:22.04-py3
Additional context
@peri044 did we root-cause where this ONNX export message comes from? Was it torch.jit.freeze?
I've tried putting the model in eval mode with scripted_module = torch.jit.script(encoder.eval()), and this issue no longer occurs. However, I see a segfault after the following log line, which is a different error and needs to be debugged:
DEBUG: [Torch-TensorRT] - Pairing 0: src.1 : Input(shape: [256, 1, 256], dtype: Float32, format: NCHW\Contiguous\Linear)
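The eval-mode workaround from the comment above can be sketched as follows. This is a minimal sketch, assuming only PyTorch is installed: it builds the same encoder on CPU instead of CUDA and deliberately stops before torch_tensorrt.compile, which, per the comments in this thread, still fails later during partitioning. It only checks that the module is scripted in eval mode and still runs in plain TorchScript.

```python
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# Same layer configuration as the reproduction script, except it is built on
# CPU (assumption: no GPU or Torch-TensorRT is needed for this step).
layer = TransformerEncoderLayer(
    d_model=256,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    batch_first=False,
)
encoder = TransformerEncoder(layer, num_layers=3)

# The workaround: call .eval() *before* scripting, so the scripted module is
# in inference mode (training=False, dropout disabled).
scripted_module = torch.jit.script(encoder.eval())

# Sanity-check the scripted module in plain TorchScript before handing it to
# torch_tensorrt.compile. Input layout is (seq_len, batch, d_model) because
# batch_first=False.
src = torch.randn(256, 1, 256)
out = scripted_module(src)
print(scripted_module.training, tuple(out.shape))
```

On a GPU machine, the same scripted_module (built with device="cuda") would then be passed to torch_tensorrt.compile exactly as in the reproduction script.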
I'm not seeing the segfault, but I do see an issue from the partitioning phase:
DEBUG: [Torch-TensorRT - Debug Build] - Registering input/output torch::jit::Value for segmented graphs
Traceback (most recent call last):
  File "1165.py", line 23, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/home/narens/Developer/opensource/pytorch_org/TensorRT/py/torch_tensorrt/_compile.py", line 113, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/home/narens/Developer/opensource/pytorch_org/TensorRT/py/torch_tensorrt/ts/_compiler.py", line 134, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:122] Expected to find type str for value why_not_fast_path.126 but get nothing.
This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days