Failed to legalize operation 'torch.aten.scaled_dot_product_attention' (Self attention torch -> tosa conversion)
I am running a Torch-MLIR to TOSA MLIR conversion pipeline (command below) for a Llama 3 self-attention layer, but torch.aten.scaled_dot_product_attention is reported as an illegal operation. Can someone help me figure out the right pass setup for the TOSA conversion?
Model configuration:
Batch Size: 1
Seq Length: 12
Hidden Size: 4096
Dimension (Head): 128
Dimension (Model): 4096
With this, I generated Torch MLIR for the self-attention layer with the following state vectors:
Hidden State, Shape: torch.Size([1, 12, 4096]), Dtype: torch.bfloat16
Position Embeddings - Cos, Shape: torch.Size([1, 12, 128]), Dtype: torch.bfloat16
Position Embeddings - Sin, Shape: torch.Size([1, 12, 128]), Dtype: torch.bfloat16
Attention Mask, Shape: torch.Size([1, 1, 1, 12]), Dtype: torch.bfloat16
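For reference, inputs with these shapes and dtypes can be created like this (a sketch with random placeholder values, not the actual model activations):

import torch

# Placeholders matching the shapes/dtypes above (random values, not real activations).
hidden_state = torch.randn(1, 12, 4096, dtype=torch.bfloat16)
pos_emb_cos = torch.randn(1, 12, 128, dtype=torch.bfloat16)
pos_emb_sin = torch.randn(1, 12, 128, dtype=torch.bfloat16)
attention_mask = torch.zeros(1, 1, 1, 12, dtype=torch.bfloat16)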
Now I am converting the Torch MLIR to TOSA MLIR with the following command:
torch-mlir-opt --torch-decompose-complex-ops --torch-function-to-torch-backend-pipeline --torch-backend-to-tosa-backend-pipeline layer0_self_attn.torch.mlir -o layer0_self_attn.tosa.mlir
This is the error I get for scaled_dot_product_attention during the conversion:
layer0_self_attn.torch.mlir:66:11: error: failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
  %47 = torch.aten.scaled_dot_product_attention %28, %41, %46, %arg3, %float0.000000e00, %false, %float8.838830e-02, %false : !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[1,32,12,128],bf16>
          ^
layer0_self_attn.torch.mlir:66:11: note: see current operation: %164 = "torch.aten.scaled_dot_product_attention"(%112, %154, %163, %arg3, %6, %8, %5, %8) : (!torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool) -> !torch.vtensor<[1,32,12,128],bf16>
Hi @HemKava, scaled_dot_product_attention doesn't have a legalization from Torch to TOSA yet, but you are welcome to contribute one and add me or @sjarus as a reviewer.
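For anyone picking that up, this is roughly what the op computes, written as a plain-PyTorch sketch (ignoring dropout, is_causal, boolean masks, and GQA), and it is the math any decomposition or legalization would need to express:

import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, scale=None):
    # Default scale is 1/sqrt(head_dim), as in aten.scaled_dot_product_attention.
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))
    # Scaled attention scores: Q @ K^T * scale.
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    # A floating-point mask (like the bf16 mask above) is added to the scores.
    if attn_mask is not None:
        scores = scores + attn_mask
    # Softmax over the key dimension, then weight the values.
    return torch.matmul(torch.softmax(scores, dim=-1), value)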
@justin-ngo-arm Thanks for responding. Do you know of any alternative pass to break down scaled_dot_product_attention prior to TOSA conversion?
I'm not entirely sure, but I believe that there is none. Others might have better suggestions though.
Hi @HemKava, I was looking through some older PRs and found one that tried to add a decomposition for ScaledDotProductAttention a while back: https://github.com/llvm/torch-mlir/pull/3461. It hasn't been accepted yet, for the reasons described in the PR's comments, but it might be worth cherry-picking and testing locally.
One option is to use PyTorch's decompositions to break the op into core ATen ops before importing into MLIR. Decomposing early lets the import succeed and produce TOSA IR:
import torch
from torch_mlir import fx


class ScaledDotProductAttentionSameModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention(query, key, value)


m = ScaledDotProductAttentionSameModule()
x = torch.randn(1, 5, 5, dtype=torch.float32)

# Export the module, then run PyTorch's decompositions so that
# scaled_dot_product_attention is broken down into core ATen ops.
ep = torch.export.export(m, (x, x, x))
ep = ep.run_decompositions()

# Import the decomposed program and lower it to TOSA.
module = fx.export_and_import(ep, output_type="tosa")
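Printing the resulting module should show the attention expressed in TOSA ops rather than the aten op. The same approach should work for your bf16 Llama attention layer as long as everything the decomposition produces has a TOSA legalization; I haven't verified the bf16 path, so treat that part as an assumption.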