'torch.aten.mm' that was explicitly marked illegal
I am running the Torch-MLIR to TOSA conversion pipeline (command below), but torch.aten.mm keeps being reported as an illegal op. Can someone help me figure out the correct passes for the TOSA conversion?
torch-mlir-opt --torch-function-to-torch-backend-pipeline --torch-backend-to-tosa-backend-pipeline torch.mlir -o tosa.mlir
torch.mlir:7:10: error: failed to legalize operation 'torch.aten.mm' that was explicitly marked illegal
    %2 = torch.aten.mm %arg0, %1 : !torch.vtensor<[32,4096],f32>, !torch.vtensor<[4096,128256],bf16> -> !torch.vtensor<[32,128256],f32>
         ^
torch.mlir:7:10: note: see current operation: %8 = "torch.aten.mm"(%arg0, %7) : (!torch.vtensor<[32,4096],f32>, !torch.vtensor<[4096,128256],bf16>) -> !torch.vtensor<[32,128256],f32>
fx.export_and_import also fails for this with the TOSA dialect, yielding the following error:
/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Torch Backend IR -> TOSA Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/UnnammedModule.mlir
Hi, the element types of your inputs don't match each other (f32 and bf16), which is not allowed when lowering torch.aten.mm to TOSA.
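To make the mismatch easy to spot before running the pipeline, here is a minimal sketch that reads the operand element types off a torch.aten.mm line in the textual IR. The helper names (`mm_operand_dtypes`, `operands_match`) and the regex are illustrative assumptions, not part of Torch-MLIR, and the sketch only handles the pretty-printed `!torch.vtensor<[shape],dtype>` form shown in the error above.

```python
import re

# Hypothetical helper, not a torch-mlir API: captures the element type
# from each !torch.vtensor<[...],dtype> found in one line of textual IR.
VTENSOR_RE = re.compile(r"!torch\.vtensor<\[[^\]]*\],\s*([A-Za-z0-9]+)>")


def mm_operand_dtypes(op_line):
    """Return (lhs_dtype, rhs_dtype) for a torch.aten.mm line."""
    dtypes = VTENSOR_RE.findall(op_line)
    if "torch.aten.mm" not in op_line or len(dtypes) < 2:
        raise ValueError("expected a torch.aten.mm line with two vtensor operand types")
    # The first two matches are the operands; a third, if present, is the result.
    return dtypes[0], dtypes[1]


def operands_match(op_line):
    """True when both matmul operands share the same element type."""
    lhs, rhs = mm_operand_dtypes(op_line)
    return lhs == rhs


# The op from the error message in this thread.
op = (
    "%2 = torch.aten.mm %arg0, %1 : !torch.vtensor<[32,4096],f32>, "
    "!torch.vtensor<[4096,128256],bf16> -> !torch.vtensor<[32,128256],f32>"
)
print(mm_operand_dtypes(op))
print(operands_match(op))
```

If the check reports a mismatch, one option is to cast one operand in the PyTorch model before exporting so that both share an element type (for example with `Tensor.to`); which dtype to pick depends on the model.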
thanks for the note @justin-ngo-arm