
[DRAFT][TorchToLinalg] Implement lowering of torch.aten.conj_physical

Open • pkapris-syrmia opened this issue 1 year ago • 5 comments

I've attempted to implement torch.aten.conj_physical, but I've had trouble lowering the operator. My lowering works in all the trivial cases where it is essentially a no-op, for example taking the conjugate of a real or integer tensor. However, when conj_physical has to be lowered to the MLIR operation complex.conj, I run into an issue. This is the test (in the commit) that fails:

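import torch

from torch_mlir_e2e_test.annotations import annotate_args, export
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case


class ElementwiseConjPhysicalComplexModule(torch.nn.Module):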
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.complex64, True),
        ]
    )
    def forward(self, a):
        return torch.ops.aten.conj_physical(a)

@register_test_case(module_factory=lambda: ElementwiseConjPhysicalComplexModule())
def ElementwiseConjPhysicalComplexModule_basic(module, tu: TestUtils):
    module.forward(torch.view_as_complex(tu.rand(3, 4, 2)))

This is the output:

TORCH_VERSION_FOR_COMPARISON = 2.5.0.dev20240718
Running tests sequentially with progress status
*** RUNNING TEST: ElementwiseConjPhysicalComplexModule_basic ***
Compiling ElementwiseConjPhysicalComplexModule_basic...

====================
TorchScript RAW IR
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  func.func private @__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule.forward(%arg0: !torch.nn.Module<"__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?,?],complex<f32>>}) -> !torch.tensor {
    %1 = torch.aten.conj_physical %arg1 : !torch.tensor -> !torch.tensor
    return %1 : !torch.tensor
  }
  torch.class_type @__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule {
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.method "forward", @__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule.forward
  }
  %true = torch.constant.bool true
  %none = torch.constant.none
  %0 = torch.nn_module {
    torch.slot "training", %true : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
  } : !torch.nn.Module<"__torch__.torch_mlir_e2e_test.test_suite.elementwise.ElementwiseConjPhysicalComplexModule">
}

====================
Torch Backend IR
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  func.func @forward(%arg0: !torch.vtensor<[?,?],complex<f32>>) -> !torch.vtensor<[?,?],complex<f32>> {
    %0 = torch.aten.conj_physical %arg0 : !torch.vtensor<[?,?],complex<f32>> -> !torch.vtensor<[?,?],complex<f32>>
    return %0 : !torch.vtensor<[?,?],complex<f32>>
  }
}

====================
LINALG Backend IR
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xcomplex<f32>>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xcomplex<f32>>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xcomplex<f32>>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xcomplex<f32>>) outs(%0 : tensor<?x?xcomplex<f32>>) {
    ^bb0(%in: complex<f32>, %out: complex<f32>):
      %2 = complex.conj %in : complex<f32>
      linalg.yield %2 : complex<f32>
    } -> tensor<?x?xcomplex<f32>>
    return %1 : tensor<?x?xcomplex<f32>>
  }
}

TORCH_VERSION_FOR_COMPARISON = 2.5.0.dev20240718
FAIL - "ElementwiseConjPhysicalComplexModule_basic"

Unexpected outcome summary: (linalg)

****** Failed tests - 1 tests
    FAIL - "ElementwiseConjPhysicalComplexModule_basic"
        Compilation error: Traceback (most recent call last):
          File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 313, in compile_and_run_test
            compiled = config.compile(test.program_factory(), verbose=verbose)
          File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py", line 38, in compile
            return self.backend.compile(module)
          File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py", line 227, in compile
            run_pipeline_with_repro_report(
          File "/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 78, in run_pipeline_with_repro_report
            raise TorchMlirCompilerError(trimmed_message) from None
        torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Linalg-on-Tensors IR to LLVM with RefBackend failed with the following diagnostics:
        error: failed to legalize operation 'complex.conj' that was explicitly marked illegal
        note: see current operation: %54 = "complex.conj"(%53) <{fastmath = #arith.fastmath<none>}> : (complex<f32>) -> complex<f32>

        python exception: Failure while executing pass pipeline

        For Torch-MLIR developers, the error can be reproduced with:
        $ torch-mlir-opt -pass-pipeline='builtin.module(func.func(linalg-generalize-named-ops),func.func(linalg-fuse-elementwise-ops),convert-shape-to-std,sparse-assembler{direct-out},sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(expand-realloc),func.func(refback-generalize-tensor-pad),func.func(refback-generalize-tensor-concat),func.func(tm-tensor-bufferize),one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map},refback-mlprogram-bufferize,func.func(finalizing-bufferize),func.func(buffer-deallocation),inline,refback-munge-calling-conventions,func.func(tm-tensor-to-loops),func.func(refback-munge-memref-copy),func.func(convert-linalg-to-loops),func.func(lower-affine),convert-scf-to-cf,func.func(refback-expand-ops-for-llvm),func.func(arith-expand),func.func(convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,lower-affine,convert-bufferization-to-memref,finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,convert-cf-to-llvm,convert-complex-to-llvm,reconcile-unrealized-casts)' /tmp/ElementwiseConjPhysicalComplexModule.mlir
        Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.

Summary:
    Failed: 1

When I run the long command above to reproduce the error, I get a different error, caused by the pass pipeline in the command itself:

<unknown>:0: error: MLIR Textual PassPipeline Parser:1:11: error: 'linalg-generalize-named-ops' does not refer to a registered pass or pass pipeline

pkapris-syrmia • Aug 08 '24 07:08
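For reference, the values the ElementwiseConjPhysicalComplexModule test compares against can be sketched in plain Python. This is a minimal sketch, not Torch-MLIR code: `conj_physical_reference` is a hypothetical helper that works on nested lists instead of tensors and assumes complex elements are Python `complex` values. It shows why the real and integer cases are no-ops and what the `complex.conj` inside the linalg.generic body has to compute for each element.

```python
def conj_physical_reference(rows):
    """Hypothetical element-wise reference for aten.conj_physical on a 2-D nested list.

    Complex elements get their imaginary part negated; real and integer
    elements are returned unchanged (the trivial "no-op" cases).
    """
    out = []
    for row in rows:
        new_row = []
        for x in row:
            if isinstance(x, complex):
                # Per-element work of complex.conj: re + i*im -> re - i*im.
                new_row.append(complex(x.real, -x.imag))
            else:
                # int / float: conjugation is the identity.
                new_row.append(x)
        out.append(new_row)
    return out


# A 2x2 complex input, analogous to torch.view_as_complex(tu.rand(2, 2, 2)).
print(conj_physical_reference([[1 + 2j, 3 - 4j], [0.5j, 2 + 0j]]))
```

The nested loop plays the role of the two "parallel" iterators of the linalg.generic, and the branch on the element type corresponds to the choice between emitting `complex.conj` and passing the input through.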

I've discovered that I need to split the pipeline between mlir-opt and torch-mlir-opt to recreate the issue:

</tmp/ElementwiseConjPhysicalComplexModule.mlir mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(linalg-generalize-named-ops),func.func(linalg-fuse-elementwise-ops),convert-shape-to-std,sparse-assembler{direct-out},sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(expand-realloc))' |
  torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(refback-generalize-tensor-pad),func.func(refback-generalize-tensor-concat),func.func(tm-tensor-bufferize))' |
  mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map})' |
  torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(refback-mlprogram-bufferize)' |
  mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(finalizing-bufferize),func.func(buffer-deallocation),inline)' |
  torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(refback-munge-calling-conventions,func.func(tm-tensor-to-loops),func.func(refback-munge-memref-copy))' |
  mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(convert-linalg-to-loops),func.func(lower-affine),convert-scf-to-cf)' |
  torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(refback-expand-ops-for-llvm))' |
  mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(func.func(arith-expand),func.func(convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,lower-affine,convert-bufferization-to-memref,finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,convert-cf-to-llvm,convert-complex-to-llvm,reconcile-unrealized-casts)'

This is the input file to the command:

#loc = loc(unknown)
#loc1 = loc("/home/[email protected]/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/elementwise.py":4848:15)
#map = affine_map<(d0, d1) -> (d0, d1)>
#loc2 = loc("aten::conj_physical"(#loc1))
module attributes {torch.debug_module_name = "ElementwiseConjPhysicalComplexModule"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64> loc(#loc)
  func.func @forward(%arg0: tensor<?x?xcomplex<f32>> loc(unknown)) -> tensor<?x?xcomplex<f32>> {
    %c1 = arith.constant 1 : index loc(#loc)
    %c0 = arith.constant 0 : index loc(#loc)
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xcomplex<f32>> loc(#loc2)
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xcomplex<f32>> loc(#loc2)
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xcomplex<f32>> loc(#loc2)
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xcomplex<f32>>) outs(%0 : tensor<?x?xcomplex<f32>>) {
    ^bb0(%in: complex<f32> loc("aten::conj_physical"(#loc1)), %out: complex<f32> loc("aten::conj_physical"(#loc1))):
      %2 = complex.conj %in : complex<f32> loc(#loc2)
      linalg.yield %2 : complex<f32> loc(#loc2)
    } -> tensor<?x?xcomplex<f32>> loc(#loc2)
    return %1 : tensor<?x?xcomplex<f32>> loc(#loc)
  } loc(#loc)
} loc(#loc)

This is the (very long) output: output.txt

pkapris-syrmia • Aug 09 '24 11:08
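The split in the multi-stage command follows a simple rule in this particular pipeline: passes whose names start with `refback-` or `tm-tensor-` go to torch-mlir-opt, everything else goes to mlir-opt, and consecutive passes for the same tool are grouped into one stage. A minimal Python sketch that derives such a command from the single `-pass-pipeline` string printed by the test harness is below. The helper names (`split_top_level`, `split_pipeline`, `to_shell`) and the prefix rule are assumptions inferred from this one pipeline, not a Torch-MLIR utility.

```python
# Hypothetical helper: route each pass of a RefBackend-style pipeline to
# mlir-opt or torch-mlir-opt. Assumption: only passes whose name starts
# with one of these prefixes need torch-mlir-opt.
TORCH_MLIR_PREFIXES = ("refback-", "tm-tensor-")


def split_top_level(body):
    """Split 'a,func.func(b),c{x y}' on commas that are not nested in () or {}."""
    entries, depth, start = [], 0, 0
    for i, ch in enumerate(body):
        if ch in "({":
            depth += 1
        elif ch in ")}":
            depth -= 1
        elif ch == "," and depth == 0:
            entries.append(body[start:i])
            start = i + 1
    entries.append(body[start:])
    return [e.strip() for e in entries if e.strip()]


def pass_name(entry):
    """'func.func(tm-tensor-bufferize)' -> 'tm-tensor-bufferize'; drops {options}."""
    if entry.startswith("func.func(") and entry.endswith(")"):
        entry = entry[len("func.func("):-1]
    return entry.split("{", 1)[0]


def tool_for(entry):
    name = pass_name(entry)
    return "torch-mlir-opt" if name.startswith(TORCH_MLIR_PREFIXES) else "mlir-opt"


def split_pipeline(pipeline):
    """Return [(tool, [entries...]), ...], grouping consecutive passes per tool."""
    prefix = "builtin.module("
    assert pipeline.startswith(prefix) and pipeline.endswith(")")
    groups = []
    for entry in split_top_level(pipeline[len(prefix):-1]):
        tool = tool_for(entry)
        if groups and groups[-1][0] == tool:
            groups[-1][1].append(entry)
        else:
            groups.append((tool, [entry]))
    return groups


def to_shell(pipeline, input_path):
    """Build a piped shell command, one stage per tool group."""
    stages = [
        f"{tool} -pass-pipeline='builtin.module({','.join(entries)})'"
        for tool, entries in split_pipeline(pipeline)
    ]
    return f"<{input_path} " + " |\n  ".join(stages)
```

Applied to the pipeline from the failing test, this rule reproduces the nine stages used above; `-mlir-print-ir-after-all -mlir-disable-threading` can be appended to each stage to get the per-pass IR dumps.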

The convert-complex-to-llvm pass doesn't seem to have a pattern for complex.conj; you might have to run convert-complex-to-standard on the IR first.

ubfx • Aug 09 '24 14:08
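To illustrate the suggestion: to my understanding, convert-complex-to-standard rewrites `complex.conj` into `complex.re`, `complex.im`, an `arith.negf` on the imaginary part, and `complex.create`. Those ops do have lowerings (`complex.re`, `complex.im` and `complex.create` in convert-complex-to-llvm, `arith.negf` in convert-arith-to-llvm). The scalar Python sketch below mirrors that op sequence one line per op; `expand_conj` is a hypothetical name, and the op sequence is my reading of the upstream pattern, so it is worth confirming against the actual pass output with `-mlir-print-ir-after-all`.

```python
def expand_conj(z: complex) -> complex:
    """Scalar mirror of the op sequence assumed for lowering complex.conj.

    Each line stands for one op that convert-complex-to-standard is assumed
    to emit; this illustrates the arithmetic, it is not the MLIR pattern.
    """
    re = z.real                 # %re  = complex.re %z
    im = z.imag                 # %im  = complex.im %z
    neg_im = -im                # %neg = arith.negf %im
    return complex(re, neg_im)  # complex.create %re, %neg


print(expand_conj(3 + 4j) == (3 + 4j).conjugate())
```

Because the imaginary part goes through a plain floating-point negation, an imaginary part of +0.0 comes out as -0.0, as IEEE negation prescribes.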

@ubfx I tried this. I found the definition of the lowering pipeline and added the marked line, and the test passes now. I also ran build_tools/ci/test_posix, and all the other linalg tests still seem to pass with this pass added.

Should this pass be added to the pipeline in Torch-MLIR, then? And if so, should I add it in this PR?

LOWERING_PIPELINE = (
    "builtin.module("
    + ",".join(
        [
            # Apply some optimizations. It would be great if MLIR had more useful
            # optimizations that worked out of the box here.
            # Note: When measured, this doesn't seem to actually help that much
            # for the linalg-on-tensors backend.
            # This is likely because if things are naturally fusable we usually already
            # emit things in that form from the high level (e.g. single linalg-generic).
            # Other backends are likely to benefit more.
            "func.func(linalg-generalize-named-ops)",
            "func.func(linalg-fuse-elementwise-ops)",
            "convert-shape-to-std",
            # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum
            # to ensure operations on sparse tensors are lowered to loops.
            "sparse-assembler{direct-out}",
            "sparsification-and-bufferization",
            "sparse-storage-specifier-to-llvm",
            # Buffer deallocation pass does not know how to handle realloc.
            "func.func(expand-realloc)",
            # Generalize pad and concat after sparse compiler, as they are handled
            # differently when the operations involve sparse operand.
            "func.func(refback-generalize-tensor-pad)",
            "func.func(refback-generalize-tensor-concat)",
            # Bufferize.
            "func.func(tm-tensor-bufferize)",
            "one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}",
            "refback-mlprogram-bufferize",
            "func.func(finalizing-bufferize)",
            "func.func(buffer-deallocation)",
            # Buffer-deallocation does not work with the inlined code generated
            # by sparse tensor dialect.
            "inline",  # inline sparse helper methods where useful
            # Munge to make it ExecutionEngine compatible.
            # Specifically, we rewrite calling convention boundaries to be in terms
            # of unranked memref, and we rewrite the return to actually be a
            # callback that consumes the return (the final munged function always
            # returns void at the C level -- we get the return value by providing the
            # callback).
            "refback-munge-calling-conventions",
            # Insert global variable and instruction sequence for getting the next
            # global seed used in stateful rng.
            # Lower to LLVM
            "func.func(tm-tensor-to-loops)",
            "func.func(refback-munge-memref-copy)",
            "func.func(convert-linalg-to-loops)",
            "func.func(lower-affine)",
            # THE LINE BELOW HAS BEEN ADDED
            "convert-complex-to-standard",
            "convert-scf-to-cf",
            "func.func(refback-expand-ops-for-llvm)",
            "func.func(arith-expand)",
            "func.func(convert-math-to-llvm)",
            # Handle some complex mlir::math ops (e.g. atan2)
            "convert-math-to-libm",
            "expand-strided-metadata",
            "finalize-memref-to-llvm",
            "lower-affine",
            "convert-bufferization-to-memref",
            "finalize-memref-to-llvm",
            "func.func(convert-arith-to-llvm)",
            "convert-vector-to-llvm",
            "convert-func-to-llvm",
            "convert-cf-to-llvm",
            "convert-complex-to-llvm",
            "reconcile-unrealized-casts",
        ]
    )
    + ")"
)

pkapris-syrmia • Aug 13 '24 10:08

Did you experiment with the placement of the pass within the pipeline? There might be advantages to placing it a bit further back in the pipeline.

Generally, I can't see anything wrong with adding an upstream MLIR pass to the pipeline. I would add it to this PR, since it is required to make the complex.conj lowering work. Reviewers can chip in if there are specific considerations about changing the pipeline.

ubfx • Aug 13 '24 10:08

I did. At first I placed it close to the end, just before "convert-complex-to-llvm", but that seems to be a bit too late, so I moved it further back (earlier in the pipeline). I'm not sure what the optimal position would be.

pkapris-syrmia • Aug 13 '24 10:08
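On placement: one plausible reason the position just before convert-complex-to-llvm was too late is that convert-complex-to-standard produces `arith` ops (for `complex.conj`, an `arith.negf`), and for some other complex ops `math` ops. Those are only lowered if the pass runs before func.func(convert-math-to-llvm), convert-math-to-libm and func.func(convert-arith-to-llvm). The sketch below checks that constraint on a LOWERING_PIPELINE-style list of pass strings; the constraint set and the helper names are assumptions based on that reasoning, not something Torch-MLIR enforces.

```python
# Hypothetical ordering check for a LOWERING_PIPELINE-style list.
# Assumption: convert-complex-to-standard emits arith (and, for some ops,
# math) operations, so it must run before the passes that lower those
# dialects, and before convert-complex-to-llvm itself.
MUST_RUN_AFTER_COMPLEX_TO_STANDARD = (
    "convert-math-to-llvm",
    "convert-math-to-libm",
    "convert-arith-to-llvm",
    "convert-complex-to-llvm",
)


def bare_name(entry):
    """'func.func(convert-arith-to-llvm)' -> 'convert-arith-to-llvm'; drops {options}."""
    if entry.startswith("func.func(") and entry.endswith(")"):
        entry = entry[len("func.func("):-1]
    return entry.split("{", 1)[0]


def ordering_problems(passes, new_pass="convert-complex-to-standard"):
    """Return human-readable ordering violations (an empty list means OK)."""
    names = [bare_name(p) for p in passes]
    if new_pass not in names:
        return [f"{new_pass} is missing"]
    pos = names.index(new_pass)
    return [
        f"{later} runs before {new_pass}"
        for later in MUST_RUN_AFTER_COMPLEX_TO_STANDARD
        if later in names and names.index(later) < pos
    ]
```

With the pass placed after func.func(lower-affine), as in the modified pipeline, the check reports no problems; moving it to just before convert-complex-to-llvm reports that the math and arith lowerings already ran, which would leave the emitted `arith.negf` unconverted.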