🐛 [Bug] Importing `torchao` first breaks `torch_tensorrt.dynamo.compile` during `run_decompositions`

Open dgcnz opened this issue 1 year ago • 3 comments

Bug Description

Importing torchao before importing torch_tensorrt causes F.interpolate to fail during run_decompositions with:

AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel

To Reproduce

import torchao
import torch_tensorrt
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, _, h, w = x.shape
        z = F.interpolate(x, [h*2, w*2])
        return z

model = Model().cuda()
inputs = (torch.randn((2, 4, 8, 8)).cuda(),)

with torch.no_grad():
    ep = torch.export.export(
        model,
        args=inputs,
        strict=True
    )
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(
            ep,
            inputs,
            reuse_cached_engines=False,
            cache_built_engines=False,
            require_full_compilation=True,
            min_block_size=1,
        )  

Logs:

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/projects/scripts/mre/torchao_tensorrt_import.py", line 25, in <module>
    trt_gm = torch_tensorrt.dynamo.compile(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 228, in compile
    exported_program = exported_program.run_decompositions(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 116, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 1111, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 654, in _decompose_exported_program
    gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/export/exported_program.py", line 446, in _decompose_and_get_gm_with_new_signature_constants
    gm, graph_signature = aot_export_module(
                          ^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1262, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1497, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 625, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 194, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 859, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6495, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/fx/interpreter.py", line 275, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 64, in inner
    return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_higher_order_ops/utils.py", line 37, in autograd_not_implemented_inner
    result = operator(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 3184, in upsample_nearest2d_vec
    return torch.ops.aten.upsample_nearest2d.default(input, osize)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 449, in __torch_dispatch__
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_ops.py", line 764, in decompose
    return self.py_kernels[dk](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dgcnz/.conda/envs/edge/lib/python3.12/site-packages/torch/_decomp/__init__.py", line 376, in _special_op_to_decompose_cia
    raise AssertionError(
AssertionError: Expected aten.upsample_nearest2d.default to have CompositeImplicitAutograd kernel

While executing %upsample_nearest2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_nearest2d.vec](args = (%x, [16, 16], None), kwargs = {})
Original traceback:
  File "/projects/scripts/mre/torchao_tensorrt_import.py", line 12, in forward
    z = F.interpolate(x, [h*2, w*2])

Expected behavior

The import order between torch_tensorrt and torchao should not matter.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0.dev20241008+cu124
  • PyTorch Version (e.g. 1.0): 2.6.0.dev20241009+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): -
  • Are you using local sources or building from archives: -
  • Python version: 3.12.7
  • CUDA version: 12.4
  • GPU models and configuration: NVIDIA RTX 4060 TI
  • Any other relevant information:
  • torchao==0.6.0.dev20241009+cu124
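Since the failure appears to depend on the exact nightly builds involved, it can help to capture the installed versions programmatically when reproducing. The following is a minimal sketch using only the standard library; `collect_versions` is a hypothetical helper, and the distribution names passed to it are assumed to match what `pip list` shows for your install.

```python
from importlib import metadata


def collect_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Names assumed from this report; adjust to your environment.
    for name, version in collect_versions(["torch", "torch_tensorrt", "torchao"]).items():
        print(f"{name}=={version}")
```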

Additional context

If you import torch_tensorrt first and then torchao the error disappears.
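Until the root cause is fixed, a script can check for the failing import order. The sketch below is one possible guard and not part of either library: `imported_before` is a hypothetical helper that relies on `sys.modules` being an insertion-ordered dict, which normally reflects the order in which imports began (this can be wrong if something edits `sys.modules` by hand).

```python
import sys
import warnings


def imported_before(first, second, modules=None):
    """Return True if `first` was registered before `second`.

    Hypothetical helper: uses the insertion order of the module table,
    which normally matches the order in which imports began.
    """
    table = sys.modules if modules is None else modules
    names = list(table)
    if first not in table or second not in table:
        return False  # cannot compare unless both modules are loaded
    return names.index(first) < names.index(second)


# Place after all imports; warns when the failing order from this report is detected.
if imported_before("torchao", "torch_tensorrt"):
    warnings.warn(
        "torchao was imported before torch_tensorrt; "
        "run_decompositions may fail with the upsample_nearest2d assertion."
    )
```

The optional `modules` argument exists only so the check can be exercised with a plain dict instead of the real module table.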

dgcnz avatar Oct 09 '24 19:10 dgcnz

From what I can tell from the backtrace, it seems like a torchao/dynamo issue, but @peri044 and @lanluo-nvidia, can you look at this as well as part of the quantization work?

narendasan avatar Oct 10 '24 17:10 narendasan

@dgcnz I have tried the example with the following setup:

(torch_tensorrt_release_py39) lanl@d5c2237-lcedt:~/git/py39/TensorRT$ python -m pip list | grep torch
pytorch-triton           3.1.0
torch                    2.5.0+cu124
torch_tensorrt           2.5.0+cu124
torchao                  0.6.0+cu124
torchprofile             0.0.4
torchvision              0.20.0+cu124

It works regardless of the import order. However, it does fail on the latest main, which I believe is related to the decomposition changes introduced in torch 2.6.0: https://github.com/pytorch/pytorch/pull/135080 I will take a further look and make the corresponding changes in torch_tensorrt.

lanluo-nvidia avatar Oct 10 '24 22:10 lanluo-nvidia

I think it's related to the changes I made in https://github.com/pytorch/TensorRT/pull/2790/files#diff-9cbf43cfb62cdf682eef77f6fd5cabc488bb5a91ff09bfee5d593827542b2476R3044 to keep PyTorch from decomposing those ops. I may find a way to fix it in one or two days.

HolyWu avatar Oct 11 '24 00:10 HolyWu