torch-mlir
Compile torch.nn.functional.scaled_dot_product_attention failed
Hi,
I'd like to compile projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py with torch-mlir, so I made the following modification:
--- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py
+++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py
@@ -22,7 +22,8 @@ data = torch.randint(30522, (2, 128))
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"
module = torchscript.compile(
- model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True
+ model, data, output_type="linalg-on-tensors", use_tracing=True
)
But I encountered the following error:
python exception: Failure while executing pass pipeline:
error: "__module.bert/__module.bert.bert/__module.bert.bert.encoder/__module.bert.bert.encoder.layer.0/__module.bert.bert.encoder.layer.0.attention/__module.bert.bert.encoder.layer.0.attention.self/aten::scaled_dot_product_attention"("/scratch/honghsu/torch-mlir/mlir_venv/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py":435:0): failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
note: "__module.bert/__module.bert.bert/__module.bert.bert.encoder/__module.bert.bert.encoder.layer.0/__module.bert.bert.encoder.layer.0.attention/__module.bert.bert.encoder.layer.0.attention.self/aten::scaled_dot_product_attention"("/scratch/honghsu/torch-mlir/mlir_venv/lib/python3.11/site-packages/transformers/models/bert/modeling_bert.py":435:0): see current operation: %111 = "torch.aten.scaled_dot_product_attention"(%100, %105, %110, %93, %48, %58, %57) : (!torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,2,128,64],f32>, !torch.vtensor<[2,1,128,128],f32>, !torch.float, !torch.bool, !torch.none) -> !torch.vtensor<[2,2,128,64],f32>
My environment:
torch-mlir: 20240608.126
transformers: 4.41.2
Have you ever encountered this, or did I miss something?
Thanks.
Looks like only default inputs are supported; I'm not sure if there is a plan to support this:
https://github.com/llvm/torch-mlir/blob/ca0e9066755b35c0889c6ab792265b0886325f50/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp#L1573-L1575
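For reference, the guard at that location makes the pattern bail out whenever any optional argument of `scaled_dot_product_attention` deviates from its default. Conceptually (a pure-Python sketch with a hypothetical helper name, not the actual C++ match logic):

```python
def can_lower_sdpa(attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    """Mirror the match-failure condition: the TorchToTMTensor lowering
    only succeeds when every optional argument is left at its default."""
    return attn_mask is None and dropout_p == 0.0 and not is_causal and scale is None

# The failing op in the error above passes a [2,1,128,128] attention mask,
# so the pattern rejects it and the op stays "illegal".
print(can_lower_sdpa())                    # default call: lowerable
print(can_lower_sdpa(attn_mask="mask"))    # BERT's masked call: rejected
```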
@rednoah91, does it work for you with the original configuration? For me it doesn't. I had the same error with "linalg-on-tensors" and "TOSA", but even with the original StableHLO output type it failed to convert, with the error:
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> StableHLO Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.constant.float'
note: see current operation: %84 = "torch.constant.float"() <{value = 0.000000e+00 : f64}> : () -> !torch.float
Did you have the same error?
@Hacker1337 yes, I got the same error with StableHLO output. Rolling the transformers version back to 4.40.0 works:
pip install transformers==4.40.0
But it's just a workaround.
@rednoah91 did you find a solution to this problem? It seems that only the default parameters are supported, so there is no support for dropout, attention masking, causal attention masking, or scaling (https://github.com/llvm/torch-mlir/blob/main/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp#L1590).
Should TMTensor support these features, or should they be decomposed to other primitives in TorchToTMTensor.cpp?
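One decomposition path would express the op in terms of primitives the backends already handle: softmax(Q·Kᵀ·scale + mask)·V. A minimal pure-Python sketch of that math on single-head 2-D inputs (an illustration of the decomposition only, not an actual torch-mlir rewrite pattern):

```python
import math

def matmul(a, b):
    # naive dense matrix product
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    # numerically stable row-wise softmax
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def sdpa(q, k, v, attn_mask=None, scale=None):
    """softmax(q @ k^T * scale + attn_mask) @ v, matching the defaults:
    scale falls back to 1/sqrt(head_dim), attn_mask is added pre-softmax."""
    d = len(q[0])
    scale = scale if scale is not None else 1.0 / math.sqrt(d)
    kt = [list(col) for col in zip(*k)]
    scores = [[x * scale for x in row] for row in matmul(q, kt)]
    if attn_mask is not None:
        scores = [[s + m for s, m in zip(srow, mrow)]
                  for srow, mrow in zip(scores, attn_mask)]
    attn = [softmax(row) for row in scores]
    return matmul(attn, v)
```

Each of these steps (matmul, elementwise add/mul, softmax) already lowers through the existing conversion paths, so decomposing the non-default cases in TorchToTMTensor.cpp (or in the decomposition pass) seems plausible; dropout would additionally need a no-op fold for inference.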