
🐛 [Bug] Importing `torchao` first breaks `torch_tensorrt.dynamo.compile` during `run_decompositions`

Open · dgcnz opened this issue 4 months ago · 3 comments

Bug Description

Importing torchao before importing torch_tensorrt causes F.interpolate to fail during run_decompositions with:

AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel

To Reproduce

import torchao
import torch_tensorrt
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, _, h, w = x.shape
        z = F.interpolate(x, [h*2, w*2])
        return z

model = Model().cuda()
inputs = (torch.randn((2, 4, 8, 8)).cuda(),)

with torch.no_grad():
    ep = torch.export.export(
        model,
        args=inputs,
        strict=True
    )
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(
            ep,
            inputs,
            reuse_cached_engines=False,
            cache_built_engines=False,
            require_full_compilation=True,
            min_block_size=1,
        )

Logs:

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/projects/scripts/mre/torchao_tensorrt_import.py", line 25, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 228, in compile
    exported_program = exported_program.run_decompositions(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 116, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 1111, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 654, in _decompose_exported_program
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 446, in _decompose_and_get_gm_with_new_signature_constants
    gm, graph_signature = aot_export_module(
                          ^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 625, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 194, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 859, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6495, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 64, in inner
    return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 37, in autograd_not_implemented_inner
    result = operator(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 3184, in upsample_nearest2d_vec
    return torch.ops.aten.upsample_nearest2d.default(input, osize)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_decomp/__init__.py", line 376, in _special_op_to_decompose_cia
    raise AssertionError(
AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel

While executing %upsample_nearest2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_nearest2d.vec](args = (%x, [16, 16], None), kwargs = {})
Original traceback:
  File "/projects/scripts/mre/torchao_tensorrt_import.py", line 12, in forward
    z = F.interpolate(x, [h*2, w*2])
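The traceback shows `OpOverload.decompose()` taking its `self.py_kernels[dk]` branch twice. The first time it runs torch_tensorrt's `upsample_nearest2d_vec` for `aten.upsample_nearest2d.vec`. The second time it runs `_special_op_to_decompose_cia` for `aten.upsample_nearest2d.default`, which then asserts that the op has a C++ CompositeImplicitAutograd kernel. The minimal diagnostic sketch below reports, for a given op, whether it has such a C++ kernel and whether a Python CompositeImplicitAutograd override is registered. It relies on PyTorch internals (`torch._C._dispatch_has_kernel_for_dispatch_key`, `OpOverload.py_kernels`) that may change between releases. The helper name `cia_status` is hypothetical. Overrides that are only installed while `run_decompositions` is running will not appear if you call this outside that call.

```python
import torch
from torch._C import DispatchKey

CIA = DispatchKey.CompositeImplicitAutograd


def cia_status(op):
    """Hypothetical helper: how OpOverload.decompose() could resolve this op."""
    return {
        # True if a C++ CompositeImplicitAutograd kernel is registered with the dispatcher.
        "cpp_cia_kernel": torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), CIA),
        # True if a Python override was registered for CompositeImplicitAutograd
        # (decompose() prefers this branch, as in the traceback above).
        "python_cia_override": CIA in op.py_kernels,
    }


for op in (torch.ops.aten.upsample_nearest2d.vec, torch.ops.aten.upsample_nearest2d.default):
    print(op, cia_status(op))
```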

Expected behavior

The import order between torch_tensorrt and torchao should not matter.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0.dev20241008+cu124
  • PyTorch Version (e.g. 1.0): 2.6.0.dev20241009+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): -
  • Are you using local sources or building from archives: -
  • Python version: 3.12.7
  • CUDA version: 12.4
  • GPU models and configuration: NVIDIA RTX 4060 TI
  • Any other relevant information: torchao==0.6.0.dev20241009+cu124

Additional context

If torch_tensorrt is imported first and torchao second, the error does not occur.
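Until this is fixed, the workaround can be enforced in code. This is a minimal sketch that uses only the standard library, and the function name `import_trt_before_torchao` is hypothetical. It raises if torchao is already loaded but torch_tensorrt is not, which can happen when another dependency imports torchao first. Otherwise it imports the two packages in the working order and skips any that are not installed.

```python
import importlib
import importlib.util
import sys


def import_trt_before_torchao():
    """Hypothetical helper: import torch_tensorrt before torchao.

    Returns the names of the packages that were imported; packages that are
    not installed are skipped.
    """
    if "torchao" in sys.modules and "torch_tensorrt" not in sys.modules:
        # torchao is already loaded, so the working import order can no longer be applied.
        raise RuntimeError("torchao was imported before torch_tensorrt; import torch_tensorrt first")
    imported = []
    for name in ("torch_tensorrt", "torchao"):  # order taken from this report
        if importlib.util.find_spec(name) is not None:
            importlib.import_module(name)
            imported.append(name)
    return imported
```

Calling it at the very top of the entry-point script, before any module that may import torchao, preserves the order reported to work here.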

dgcnz · Oct 09 '24 19:10