[DRAFT][TorchToLinalg] Implement lowering of torch.aten.conj_physical
I've attempted to implement torch.aten.conj_physical, but I've run into issues lowering the operator. My lowering works in all the trivial situations where it is essentially a no-op, for example taking the conjugate of a real number or an integer. However, when conj_physical has to lower to the MLIR operator complex.conj, I hit a failure. This is the test (in the commit) that fails:
class ElementwiseConjPhysicalComplexModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.complex64, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.conj_physical(a)


@register_test_case(module_factory=lambda: ElementwiseConjPhysicalComplexModule())
def ElementwiseConjPhysicalComplexModule_basic(module, tu: TestUtils):
    module.forward(torch.view_as_complex(tu.rand(3, 4, 2)))
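For context, the behavior the lowering needs to match can be sketched in plain Python (illustrative only, not torch-mlir code; the helper name is mine):

```python
# Illustrative sketch of aten.conj_physical semantics: identity for
# real/integer values, negated imaginary part for complex values. This
# is what complex.conj computes elementwise inside the linalg.generic.
def conj_physical(values):
    """Elementwise physical conjugate over a flat list of numbers."""
    return [v.conjugate() if isinstance(v, complex) else v for v in values]

# No-op for reals and integers:
assert conj_physical([1.5, -2.0, 3]) == [1.5, -2.0, 3]
# conj(a + bi) = a - bi for complex values:
assert conj_physical([1 + 2j, -3 - 4j]) == [1 - 2j, -3 + 4j]
```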
This is the output:
TORCH_VERSION_FOR_COMPARISON = 2.5.0.dev20240718
Running tests sequentially with progress status
*** RUNNING TEST: ElementwiseConjPhysicalComplexModule_basic ***
Compiling ElementwiseConjPhysicalComplexModule_basic...
====================
TorchScript RAW IR
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  func.func private @__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule.forward(%arg0: !torch.nn.Module<"__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,?],complex<f32>>}) -> !torch.tensor {
    %1 = torch.aten.conj_physical %arg1 : !torch.tensor -> !torch.tensor
    return %1 : !torch.tensor
  }
  torch.class_type @__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule {
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.method "forward", @__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule.forward
  }
  %true = torch.constant.bool true
  %none = torch.constant.none
  %0 = torch.nn_module {
    torch.slot "training", %true : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
  } : !torch.nn.Module<"__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule">
}
====================
Torch Backend IR
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  func.func @forward(%arg0: !torch.vtensor<[?,?],complex<f32>>) -> !torch.vtensor<[?,?],complex<f32>> {
    %0 = torch.aten.conj_physical %arg0 : !torch.vtensor<[?,?],complex<f32>> -> !torch.vtensor<[?,?],complex<f32>>
    return %0 : !torch.vtensor<[?,?],complex<f32>>
  }
}
====================
LINALG Backend IR
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xcomplex<f32>>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xcomplex<f32>>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xcomplex<f32>>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xcomplex<f32>>) outs(%0 : tensor<?x?xcomplex<f32>>) {
    ^bb0(%in: complex<f32>, %out: complex<f32>):
      %2 = complex.conj %in : complex<f32>
      linalg.yield %2 : complex<f32>
    } -> tensor<?x?xcomplex<f32>>
    return %1 : tensor<?x?xcomplex<f32>>
  }
}
TORCH_VERSION_FOR_COMPARISON = 2.5.0.dev20240718
FAIL - "ElementwiseConjPhysicalComplexModule_basic"
Unexpected outcome summary: (linalg)
****** Failed tests - 1 tests
FAIL - "ElementwiseConjPhysicalComplexModule_basic"
Compilation error: Traceback (most recent call last):
  File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 313, in compile_and_run_test
    compiled = config.compile(test.program_factory(), verbose=verbose)
  File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py", line 38, in compile
    return self.backend.compile(module)
  File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py", line 227, in compile
    run_pipeline_with_repro_report(
  File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 78, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Linalg-on-Tensors IR to LLVM with RefBackend failed with the following diagnostics:
error: failed to legalize operation 'complex.conj' that was explicitly marked illegal
note: see current operation: %54 = "complex.conj"(%53) <{fastmath = #arith.fastmath<none>}> : (complex<f32>) -> complex<f32>
python exception: Failure while executing pass pipeline
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(func.func(linalg-generalize-named-ops),func.func(linalg-fuse-elementwise-ops),convert-shape-to-std,sparse-assembler{direct-out},sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(expand-realloc),func.func(refback-generalize-tensor-pad),func.func(refback-generalize-tensor-concat),func.func(tm-tensor-bufferize),one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map},refback-mlprogram-bufferize,func.func(finalizing-bufferize),func.func(buffer-deallocation),inline,refback-munge-calling-conventions,func.func(tm-tensor-to-loops),func.func(refback-munge-memref-copy),func.func(convert-linalg-to-loops),func.func(lower-affine),convert-scf-to-cf,func.func(refback-expand-ops-for-llvm),func.func(arith-expand),func.func(convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,lower-affine,convert-bufferization-to-memref,finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,convert-cf-to-llvm,convert-complex-to-llvm,reconcile-unrealized-casts)' /tmp/ElementwiseConjPhysicalComplexModule.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Summary:
Failed: 1
When I try to run the long command above to reproduce the error, I instead get a different error, about the pass names in the command itself:
<unknown>:0: error: MLIR Textual PassPipeline Parser:1:11: error: 'linalg-generalize-named-ops' does not refer to a registered pass or pass pipeline
It appears torch-mlir-opt does not register some of the upstream passes, so I've discovered I need to split the pipeline between mlir-opt and torch-mlir-opt and run this command to recreate the issue:
</tmp/ElementwiseConjPhysicalComplexModule.mlir mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(linalg-generalize-named-ops),func.func(linalg-fuse-elementwise-ops),convert-shape-to-std,sparse-assembler{direct-out},sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(expand-realloc))' | torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(refback-generalize-tensor-pad),func.func(refback-generalize-tensor-concat),func.func(tm-tensor-bufferize))' | mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map})' | torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(refback-mlprogram-bufferize)' | mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(finalizing-bufferize),func.func(buffer-deallocation),inline)' | torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(refback-munge-calling-conventions,func.func(tm-tensor-to-loops),func.func(refback-munge-memref-copy))' | mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(convert-linalg-to-loops),func.func(lower-affine),convert-scf-to-cf)' | torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(refback-expand-ops-for-llvm))' | mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(arith-expand),func.func(convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,lower-affine,convert-bufferization-to-memref,finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,convert-cf-to-llvm,convert-complex-to-llvm,reconcile-unrealized-casts)'
This is the input file to the command:
#loc = loc(unknown)
#loc1 = loc("/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/elementwise.py":4848:15)
#map = affine_map<(d0, d1) -> (d0, d1)>
#loc2 = loc("aten::conj_physical"(#loc1))
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64> loc(#loc)
  func.func @forward(%arg0: tensor<?x?xcomplex<f32>> loc(unknown)) -> tensor<?x?xcomplex<f32>> {
    %c1 = arith.constant 1 : index loc(#loc)
    %c0 = arith.constant 0 : index loc(#loc)
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xcomplex<f32>> loc(#loc2)
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xcomplex<f32>> loc(#loc2)
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xcomplex<f32>> loc(#loc2)
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xcomplex<f32>>) outs(%0 : tensor<?x?xcomplex<f32>>) {
    ^bb0(%in: complex<f32> loc("aten::conj_physical"(#loc1)), %out: complex<f32> loc("aten::conj_physical"(#loc1))):
      %2 = complex.conj %in : complex<f32> loc(#loc2)
      linalg.yield %2 : complex<f32> loc(#loc2)
    } -> tensor<?x?xcomplex<f32>> loc(#loc2)
    return %1 : tensor<?x?xcomplex<f32>> loc(#loc)
  } loc(#loc)
} loc(#loc)
This is the (very long) output: output.txt
The convert-complex-to-llvm pass doesn't seem to have a pattern for complex.conj, you might have to run it through convert-complex-to-standard first.
@ubfx I have tried doing this. I found the definition of the lowering pipeline and added the marked line below; the test now passes. I have also run build_tools/ci/test_posix, and all the other linalg tests still pass with this change.
Should this pass then be added to the Torch-MLIR pipeline? And if so, should I add it in this PR?
LOWERING_PIPELINE = (
    "builtin.module("
    + ",".join(
        [
            # Apply some optimizations. It would be great if MLIR had more useful
            # optimizations that worked out of the box here.
            # Note: When measured, this doesn't seem to actually help that much
            # for the linalg-on-tensors backend.
            # This is likely because if things are naturally fusable we usually already
            # emit things in that form from the high level (e.g. single linalg-generic).
            # Other backends are likely to benefit more.
            "func.func(linalg-generalize-named-ops)",
            "func.func(linalg-fuse-elementwise-ops)",
            "convert-shape-to-std",
            # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum
            # to ensure operations on sparse tensors are lowered to loops.
            "sparse-assembler{direct-out}",
            "sparsification-and-bufferization",
            "sparse-storage-specifier-to-llvm",
            # Buffer deallocation pass does not know how to handle realloc.
            "func.func(expand-realloc)",
            # Generalize pad and concat after sparse compiler, as they are handled
            # differently when the operations involve sparse operand.
            "func.func(refback-generalize-tensor-pad)",
            "func.func(refback-generalize-tensor-concat)",
            # Bufferize.
            "func.func(tm-tensor-bufferize)",
            "one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}",
            "refback-mlprogram-bufferize",
            "func.func(finalizing-bufferize)",
            "func.func(buffer-deallocation)",
            # Buffer-deallocation does not work with the inlined code generated
            # by sparse tensor dialect.
            "inline",  # inline sparse helper methods where useful
            # Munge to make it ExecutionEngine compatible.
            # Specifically, we rewrite calling convention boundaries to be in terms
            # of unranked memref, and we rewrite the return to actually be a
            # callback that consumes the return (the final munged function always
            # returns void at the C level -- we get the return value by providing the
            # callback).
            "refback-munge-calling-conventions",
            # Insert global variable and instruction sequence for getting the next
            # global seed used in stateful rng.
            # Lower to LLVM
            "func.func(tm-tensor-to-loops)",
            "func.func(refback-munge-memref-copy)",
            "func.func(convert-linalg-to-loops)",
            "func.func(lower-affine)",
            # THE LINE BELOW HAS BEEN ADDED
            "convert-complex-to-standard",
            "convert-scf-to-cf",
            "func.func(refback-expand-ops-for-llvm)",
            "func.func(arith-expand)",
            "func.func(convert-math-to-llvm)",
            # Handle some complex mlir::math ops (e.g. atan2)
            "convert-math-to-libm",
            "expand-strided-metadata",
            "finalize-memref-to-llvm",
            "lower-affine",
            "convert-bufferization-to-memref",
            "finalize-memref-to-llvm",
            "func.func(convert-arith-to-llvm)",
            "convert-vector-to-llvm",
            "convert-func-to-llvm",
            "convert-cf-to-llvm",
            "convert-complex-to-llvm",
            "reconcile-unrealized-casts",
        ]
    )
    + ")"
)
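For the record, my understanding (a sketch, not a verified -mlir-print-ir-after-all dump) is that convert-complex-to-standard expands the complex.conj inside the linalg.generic into an extract/negate/rebuild sequence along these lines, which the later passes can then legalize:

```mlir
// Sketch of the expansion I believe convert-complex-to-standard performs
// for complex.conj: conj(a + bi) = a - bi, so only the imaginary part is
// negated. SSA names are illustrative.
%re  = complex.re %in : complex<f32>
%im  = complex.im %in : complex<f32>
%neg = arith.negf %im : f32
%out = complex.create %re, %neg : complex<f32>
```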
Did you experiment with the placement of the pass within the pipeline? There might be advantages to placing it a bit further back in the pipeline.
Generally, I can't see anything wrong with adding an upstream MLIR pass to the pipeline. I would add it to this PR since it is required to make the complex.conj lowering work. The reviewers can chip in in case there are specific considerations about changing the pipeline.
I did. At first I placed it close to the end, just before "convert-complex-to-llvm", but that seems to be too late, so I moved it earlier in the pipeline. I'm not sure what the optimal position would be.