
Failed to legalize operation 'torch.aten.scaled_dot_product_attention' (Self attention torch -> tosa conversion)


I am running a Torch MLIR to TOSA MLIR conversion pipeline on Llama 3's attention layer (command below), but torch.aten.scaled_dot_product_attention is reported as an illegal op. Can someone help me figure out the right pass to make the TOSA conversion work?

Configuration:
- Batch size: 1
- Sequence length: 12
- Hidden size: 4096
- Head dimension: 128
- Model dimension: 4096

With this, I generated Torch MLIR for the self-attention layer with the following inputs:
- Hidden state: shape torch.Size([1, 12, 4096]), dtype torch.bfloat16
- Position embeddings (cos): shape torch.Size([1, 12, 128]), dtype torch.bfloat16
- Position embeddings (sin): shape torch.Size([1, 12, 128]), dtype torch.bfloat16
- Attention mask: shape torch.Size([1, 1, 1, 12]), dtype torch.bfloat16

Now I am converting the Torch MLIR to TOSA MLIR as follows:

torch-mlir-opt --torch-decompose-complex-ops --torch-function-to-torch-backend-pipeline --torch-backend-to-tosa-backend-pipeline layer0_self_attn.torch.mlir -o layer0_self_attn.tosa.mlir

Following is the error I get when converting scaled_dot_product_attention:

layer0_self_attn.torch.mlir:66:11: error: failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
%47 = torch.aten.scaled_dot_product_attention %28, %41, %46, %arg3, %float0.000000e00, %false, %float8.838830e-02, %false : !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[1,32,12,128],bf16>
          ^
layer0_self_attn.torch.mlir:66:11: note: see current operation: %164 = "torch.aten.scaled_dot_product_attention"(%112, %154, %163, %arg3, %6, %8, %5, %8) : (!torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool) -> !torch.vtensor<[1,32,12,128],bf16>

HemKava avatar Jul 21 '25 15:07 HemKava

Hi @HemKava, scaled_dot_product_attention doesn't have a legalization from Torch to TOSA yet, but you are welcome to contribute one and add me or @sjarus as a reviewer.

justin-ngo-arm avatar Jul 28 '25 18:07 justin-ngo-arm

@justin-ngo-arm Thanks for responding. Do you know of any alternative pass to break down scaled_dot_product_attention prior to TOSA conversion?

HemKava avatar Jul 28 '25 19:07 HemKava

I'm not entirely sure, but I believe that there is none. Others might have better suggestions though.

justin-ngo-arm avatar Jul 28 '25 19:07 justin-ngo-arm

Hi @HemKava, I was looking through some older PRs and found one that tried to add a decomposition for ScaledDotProductAttention a while back: https://github.com/llvm/torch-mlir/pull/3461. It was not accepted for the reasons described in that PR's comments, but it might be worth cherry-picking and testing locally.
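
For reference, any such decomposition essentially has to express softmax(Q Kᵀ · scale + mask) · V using ops that already have Torch-to-TOSA legalizations (matmul, softmax, add). A minimal PyTorch sketch of that computation, just to show the math involved (the function name and arguments are illustrative; dropout and the is_causal flag are ignored):

import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, scale=None):
    # Scaled dot-product attention written with core ops (matmul, softmax)
    # instead of the fused aten.scaled_dot_product_attention op.
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)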

justin-ngo-arm avatar Aug 04 '25 18:08 justin-ngo-arm

One option is to use PyTorch's decompositions to break the op down into core ATen ops before importing into MLIR. Decomposing early lets the import succeed and produce TOSA IR:

import torch
from torch_mlir import fx

class ScaledDotProductAttentionSameModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)

m = ScaledDotProductAttentionSameModule()
x = torch.randn(1, 5, 5, dtype=torch.float32)
ep = torch.export.export(m, (x, x, x))

# Decompose scaled_dot_product_attention into core ATen ops before import.
ep = ep.run_decompositions()
module = fx.export_and_import(
    ep, output_type="tosa"
)
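
Calling run_decompositions() with no arguments lowers the exported program to the core ATen opset, which does not include scaled_dot_product_attention, so the importer only sees ops that legalize to TOSA. Printing the imported module should show tosa dialect ops with no remaining torch.aten.scaled_dot_product_attention:

print(module)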

sahas3 avatar Aug 12 '25 01:08 sahas3