
🐛 [Bug] nn.MultiheadAttention fails with Torch-TensorRT due to non-contiguous tensor before view()

Open · LinzhouLi opened this issue 3 months ago · 1 comment

Bug Description

When compiling a simple nn.MultiheadAttention module with Torch-TensorRT using the dynamo IR, compilation fails (inside run_decompositions) with an aten.view() error, because the tensor returned by scaled_dot_product_attention is not contiguous.

Adding .contiguous() fixes the problem.

To Reproduce

Code sample

import torch
import torch_tensorrt

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            embed_dim=768, num_heads=8, kdim=768, vdim=768,
            dropout=0.0, bias=False, batch_first=True
        )
    
    def forward(self, x):
        att_out, _ = self.self_attn(x, x, x, need_weights=False)
        return att_out

model = Model()
model.cuda()
model.eval()

inputs = [torch.rand([1, 1024, 768], device='cuda', dtype=torch.float32)]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "tmp.ep", inputs=inputs)

Error

TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[09/17/2025-20:06:17] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] failed while attempting to run meta for aten.view.default
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] Traceback (most recent call last):
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     r = func(*args, **kwargs)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]         ^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     return self._op(*args, **kwargs)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]            ^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 4671, in view
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     return _reshape_view_helper(a, *shape, allow_copy=False)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3800, in _reshape_view_helper
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     raise ValueError(msg)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] ValueError: Cannot view a tensor with shape torch.Size([1024, 1, 8, 96]) and strides (96, 786432, 98304, 1) as a tensor with shape (1024, 768)!
Traceback (most recent call last):
  File "/home/lilinzhou/code/Head/folder/issue.py", line 23, in <module>
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch_tensorrt/_compile.py", line 289, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 682, in compile
    exported_program = exported_program.run_decompositions(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
                           ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/_trace.py", line 816, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1355, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1594, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 899, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7183, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 320, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 525, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1384, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 4671, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3800, in _reshape_view_helper
    raise ValueError(msg)
ValueError: Cannot view a tensor with shape torch.Size([1024, 1, 8, 96]) and strides (96, 786432, 98304, 1) as a tensor with shape (1024, 768)!

While executing %view_6 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%permute, [1024, 768]), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[1, 1024, 768][786432, 768, 1]"; 
    
        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # No stacktrace found for following nodes
        self_attn_in_proj_weight: "f32[2304, 768][768, 1]" = self.self_attn.in_proj_weight
        self_attn_out_proj_weight: "f32[768, 768][768, 1]" = self.self_attn.out_proj.weight
        
         # File: /home/lilinzhou/code/Head/folder/issue.py:13 in forward, code: att_out, _ = self.self_attn(x, x, x, need_weights=False)
        transpose: "f32[1024, 1, 768][768, 786432, 1]" = torch.ops.aten.transpose.int(x, 1, 0);  x = None
        linear: "f32[1024, 1, 2304][2304, 2304, 1]" = torch.ops.aten.linear.default(transpose, self_attn_in_proj_weight);  transpose = self_attn_in_proj_weight = None
        unflatten: "f32[1024, 1, 3, 768][2304, 2304, 768, 1]" = torch.ops.aten.unflatten.int(linear, -1, [3, 768]);  linear = None
        unsqueeze: "f32[1, 1024, 1, 3, 768][2359296, 2304, 2304, 768, 1]" = torch.ops.aten.unsqueeze.default(unflatten, 0);  unflatten = None
        transpose_1: "f32[3, 1024, 1, 1, 768][768, 2304, 2304, 2359296, 1]" = torch.ops.aten.transpose.int(unsqueeze, 0, -2);  unsqueeze = None
        squeeze: "f32[3, 1024, 1, 768][768, 2304, 2304, 1]" = torch.ops.aten.squeeze.dim(transpose_1, -2);  transpose_1 = None
        contiguous: "f32[3, 1024, 1, 768][786432, 768, 768, 1]" = torch.ops.aten.contiguous.default(squeeze);  squeeze = None
        select: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 0)
        select_1: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 1)
        select_2: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 2);  contiguous = None
        view: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select, [1024, 8, 96]);  select = None
        transpose_2: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view, 0, 1);  view = None
        view_1: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select_1, [1024, 8, 96]);  select_1 = None
        transpose_3: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view_1, 0, 1);  view_1 = None
        view_2: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select_2, [1024, 8, 96]);  select_2 = None
        transpose_4: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view_2, 0, 1);  view_2 = None
        view_3: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_2, [1, 8, 1024, 96]);  transpose_2 = None
        view_4: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_3, [1, 8, 1024, 96]);  transpose_3 = None
        view_5: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_4, [1, 8, 1024, 96]);  transpose_4 = None
        scaled_dot_product_attention: "f32[1, 8, 1024, 96][786432, 96, 768, 1]" = torch.ops.aten.scaled_dot_product_attention.default(view_3, view_4, view_5);  view_3 = view_4 = view_5 = None
        permute: "f32[1024, 1, 8, 96][768, 786432, 96, 1]" = torch.ops.aten.permute.default(scaled_dot_product_attention, [2, 0, 1, 3]);  scaled_dot_product_attention = None
        view_6: "f32[1024, 768][768, 1]" = torch.ops.aten.view.default(permute, [1024, 768]);  permute = None
        linear_1: "f32[1024, 768][768, 1]" = torch.ops.aten.linear.default(view_6, self_attn_out_proj_weight);  view_6 = self_attn_out_proj_weight = None
        view_7: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.view.default(linear_1, [1024, 1, 768]);  linear_1 = None
        transpose_5: "f32[1, 1024, 768][768, 768, 1]" = torch.ops.aten.transpose.int(view_7, 1, 0);  view_7 = None
        return pytree.tree_unflatten((transpose_5,), self._out_spec)
        

Original traceback:
  File "/home/lilinzhou/code/Head/folder/issue.py", line 13, in forward
    att_out, _ = self.self_attn(x, x, x, need_weights=False)
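The failing layout can be reproduced in eager PyTorch, without Torch-TensorRT or a GPU, by building a tensor with exactly the shape and strides from the ValueError via torch.empty_strided. This is a minimal sketch; note that eager view() raises a RuntimeError rather than the ValueError produced by the fake-tensor reference implementation in the log.

```python
import torch

# Shape and strides copied from the ValueError in the log.
t = torch.empty_strided((1024, 1, 8, 96), (96, 786432, 98304, 1))

try:
    t.view(1024, 768)
    view_ok = True
except RuntimeError:
    # The 8-head axis has stride 98304 rather than 96, so it cannot be
    # merged with the 96-wide head_dim axis by reinterpreting strides alone.
    view_ok = False

# contiguous() copies into a standard row-major layout, after which view() works.
flat = t.contiguous().view(1024, 768)
print(view_ok, flat.shape)  # False torch.Size([1024, 768])
```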

Workaround

If I manually patch PyTorch's torch/nn/functional.py by adding .contiguous() right after the scaled_dot_product_attention() call, the problem goes away: https://github.com/pytorch/pytorch/blob/89a6dbe73af4ca64ee26f4e46219e163b827e698/torch/nn/functional.py#L6487-L6489

attn_output = scaled_dot_product_attention(
    q, k, v, attn_mask, dropout_p, is_causal
).contiguous()
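To check that this patch only changes memory layout and not the numerical result, here is a rough re-implementation of a batch_first, bias-free nn.MultiheadAttention forward pass around the patched line, compared against the module itself. The manual_mha helper and the tiny sizes are illustrative assumptions; it follows the (tgt_len, bsz, ...) reshaping order visible in the exported graph rather than reproducing PyTorch's source verbatim.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, seq_len = 16, 4, 5
head_dim = embed_dim // num_heads
mha = torch.nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True).eval()
x = torch.randn(1, seq_len, embed_dim)


def manual_mha(x):
    bsz, tgt_len, _ = x.shape
    # Packed in-projection: rows of in_proj_weight are [W_q; W_k; W_v].
    q, k, v = F.linear(x, mha.in_proj_weight).chunk(3, dim=-1)
    # (bsz, tgt_len, embed_dim) -> (bsz, heads, tgt_len, head_dim)
    q, k, v = (t.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    # The patched call: force a row-major copy of the SDPA output.
    attn = F.scaled_dot_product_attention(q, k, v).contiguous()
    # (bsz, heads, tgt_len, head_dim) -> (tgt_len, bsz, heads, head_dim),
    # copied to row-major so it can be flattened to (tgt_len * bsz, embed_dim).
    flat = attn.permute(2, 0, 1, 3).contiguous().view(tgt_len * bsz, embed_dim)
    out = F.linear(flat, mha.out_proj.weight)
    # Back to batch_first: (tgt_len, bsz, E) -> (bsz, tgt_len, E)
    return out.view(tgt_len, bsz, embed_dim).transpose(0, 1)


ref, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(manual_mha(x), ref, atol=1e-5))
```

The extra copies cost memory bandwidth but leave the math unchanged, which is why a .contiguous() call is a safe (if blunt) fix at this point.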

Expected behavior

MultiheadAttention should work out of the box when compiled with Torch-TensorRT without requiring manual .contiguous() hacks.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.7.0+cu118
  • PyTorch Version (e.g. 1.0): 2.7.1+cu118
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: building from archives
  • Python version: 3.12
  • CUDA version: 11.8
  • GPU models and configuration: RTX 3090 (Driver Version 550.135)
  • Any other relevant information:

Additional context

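It seems the tensor returned by scaled_dot_product_attention can be non-contiguous, while the later .view() requires a layout compatible with the target shape: after permute(2, 0, 1, 3) the tensor has strides (96, 786432, 98304, 1), which cannot be reinterpreted as shape (1024, 768) without a copy. Adding .contiguous() fixes the issue.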
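A quick stride check is consistent with this reading. The exported graph records the SDPA output with strides [786432, 96, 768, 1], and permuting that layout yields a tensor that can be viewed as (1024, 768). If the decomposed SDPA instead returns a standard row-major tensor (an assumption, not confirmed from the Torch-TensorRT source), the same permute produces exactly the strides (96, 786432, 98304, 1) from the ValueError.

```python
import torch

shape = (1, 8, 1024, 96)  # (bsz, num_heads, tgt_len, head_dim)

# Layout recorded for the SDPA output in the exported graph: memory order is
# effectively (bsz, tgt_len, heads, head_dim).
traced = torch.empty_strided(shape, (786432, 96, 768, 1))

# Standard row-major layout. ASSUMPTION: this is what the decomposed SDPA returns.
decomposed = torch.empty(shape)

for name, t in (("traced", traced), ("decomposed", decomposed)):
    p = t.permute(2, 0, 1, 3)
    print(name, p.stride(), p.is_contiguous())
# traced (768, 786432, 96, 1) True
# decomposed (96, 786432, 98304, 1) False
```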

LinzhouLi, Sep 17 '25 12:09

Thanks for reporting this issue, @LinzhouLi. It does seem like a fix in PyTorch would be the better solution here; let me check on that. In the meantime, here is a workaround: adding the following code to the script above disables the SDPA decompositions and gets past this issue.

from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS

_SDPA_OPS_TO_REMOVE = (
    torch.ops.aten.scaled_dot_product_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_cudnn_attention.default,
)


def _remove_decompositions():
    """
    Remove decompositions for SDPA operators.

    This function is idempotent. It ensures that the SDPA operators are removed
    from the decomposition table, allowing a custom converter to be used.
    """
    # Check if any of the decompositions still exist before proceeding
    if any(op in TORCH_TRT_DECOMPOSITIONS for op in _SDPA_OPS_TO_REMOVE):
        print("Removing SDPA decompositions to enable custom converter.")
        for op in _SDPA_OPS_TO_REMOVE:
            TORCH_TRT_DECOMPOSITIONS.pop(op, None)

_remove_decompositions()  # run this before calling torch_tensorrt.compile(...)
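The idempotence claimed in the docstring can be checked without a GPU or Torch-TensorRT. The sketch below uses a hypothetical stand-in dictionary (fake_table) in place of TORCH_TRT_DECOMPOSITIONS, with placeholder lambdas as values, and a hypothetical remove_decompositions helper that reports how many entries it dropped.

```python
import torch

ops_to_remove = (
    torch.ops.aten.scaled_dot_product_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
)

# Hypothetical stand-in for TORCH_TRT_DECOMPOSITIONS: op overload -> decomposition fn.
fake_table = {op: (lambda *args, **kwargs: None) for op in ops_to_remove}
fake_table[torch.ops.aten.addmm.default] = lambda *args, **kwargs: None  # unrelated entry


def remove_decompositions(table):
    """Drop SDPA entries; pop(op, None) makes repeated calls harmless."""
    removed = 0
    for op in ops_to_remove:
        if table.pop(op, None) is not None:
            removed += 1
    return removed


print(remove_decompositions(fake_table))  # 3
print(remove_decompositions(fake_table))  # 0
print(len(fake_table))  # 1
```

Because pop(op, None) never raises for a missing key, calling the function a second time is a no-op, and entries for other ops are left untouched.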

peri044, Nov 21 '25 23:11