To Reproduce
import torch
import torch.nn as nn
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 2, 2)
        self.beta = nn.Parameter(torch.ones((1, 3, 1, 1), dtype=torch.float))

    def forward(self, x):
        return self.conv(x) * self.beta


with torch.inference_mode():
    model = MyModule().eval().cuda()
    inputs = (torch.randn((1, 3, 224, 224), dtype=torch.float, device="cuda"),)
    exported_program = torch.export.export(model, inputs)
    trt_model = torch_tensorrt.dynamo.compile(
        exported_program,
        inputs,
        enabled_precisions={torch.float},
        debug=True,
        min_block_size=1,
        use_explicit_typing=True,
    )
Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
%p_beta : [num_users=1] = placeholder[target=p_beta]
%p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
%p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
%x : [num_users=1] = placeholder[target=x]
%conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias, [2, 2]), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%conv2d, %p_beta), kwargs = {})
return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
%beta : [num_users=1] = get_attr[target=beta]
%conv_weight : [num_users=1] = get_attr[target=conv.weight]
%conv_bias : [num_users=1] = get_attr[target=conv.bias]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, %beta), kwargs = {})
return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
%beta : [num_users=1] = get_attr[target=beta]
%conv_weight : [num_users=1] = get_attr[target=conv.weight]
%conv_bias : [num_users=1] = get_attr[target=conv.bias]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, %beta), kwargs = {})
return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
%beta : [num_users=1] = get_attr[target=beta]
%conv_weight : [num_users=1] = get_attr[target=conv.weight]
%conv_bias : [num_users=1] = get_attr[target=conv.bias]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, %beta), kwargs = {})
return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
%beta : [num_users=1] = get_attr[target=beta]
%conv_weight : [num_users=1] = get_attr[target=conv.weight]
%conv_bias : [num_users=1] = get_attr[target=conv.bias]
%x : [num_users=1] = placeholder[target=x]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, %beta), kwargs = {})
return (mul,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, make_refittable=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=True, use_fp32_acc=False)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.convolution.default + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 2 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.convolution.default + Operator Count: 1
- torch.ops.aten.mul.Tensor + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
Input shapes: [(1, 3, 224, 224)]
graph():
%x : [num_users=1] = placeholder[target=x]
%conv_weight : [num_users=1] = get_attr[target=conv.weight]
%conv_bias : [num_users=1] = get_attr[target=conv.bias]
%convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv_weight, %conv_bias, [2, 2], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
%beta : [num_users=1] = get_attr[target=beta]
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, %beta), kwargs = {})
return mul
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node conv_weight (kind: conv.weight, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv_weight [conv.weight] (Inputs: () | Outputs: (conv_weight: (3, 3, 2, 2)@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node conv_bias (kind: conv.bias, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv_bias [conv.bias] (Inputs: () | Outputs: (conv_bias: (3,)@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node conv/convolution (kind: aten.convolution.default, args: ('x <Node>', 'conv_weight <Node>', 'conv_bias <Node>', ['2 <int>', '2 <int>'], ['0 <int>', '0 <int>'], ['1 <int>', '1 <int>'], 'False <bool>', ['0 <int>', '0 <int>'], '1 <int>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node conv/convolution [aten.convolution.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32, conv_weight: (3, 3, 2, 2)@float32, conv_bias: (3,)@float32, [2, 2], [0, 0], [1, 1], False, [0, 0], 1) | Outputs: (convolution: (1, 3, 112, 112)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node beta (kind: beta, args: ())
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node beta [beta] (Inputs: () | Outputs: (beta: (1, 3, 1, 1)@float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /mul (kind: aten.mul.Tensor, args: ('convolution <Node>', 'beta <Node>'))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /mul [aten.mul.Tensor] (Inputs: (convolution: (1, 3, 112, 112)@torch.float32, beta: (1, 3, 1, 1)@float32) | Outputs: (mul: (1, 3, 112, 112)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('mul <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 112, 112), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (mul: (1, 3, 112, 112)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.002087
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
ERROR:torch_tensorrt [TensorRT Conversion Context]:[graphOptimizerDetails.cpp::matchTypeSpec::128] Error Code 2: Internal Error (Assertion first.outputs[0] == second.inputs[0] failed. )
Traceback (most recent call last):
  File "/home/holywu/test.py", line 22, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 308, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 482, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 141, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 120, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 630, in run
    assert serialized_engine
           ^^^^^^^^^^^^^^^^^
AssertionError
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.6.0.dev20241011+cu124
- PyTorch Version (e.g. 1.0): 2.6.0.dev20241011+cu124
- CPU Architecture: x64
- OS (e.g., Linux): Ubuntu 24.04.1
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.12.3
- CUDA version: 12.4
- GPU models and configuration: RTX 4060 Ti
- Any other relevant information: