
MHLO operation regions need to use scalar arguments

Open MaheshRavishankar opened this issue 4 years ago • 6 comments

MHLO operations that have regions use a zero-rank tensor to represent what are really scalar values. For example:

func @reduce_one_op_all_locs_same(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32> loc("foo"), %arg3: tensor<f32> loc("foo")):
    %1 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc("foo")
    "mhlo.return"(%1) : (tensor<f32>) -> () loc("foo")
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")

  return %0: tensor<?xf32>
}

There are a couple of issues here.

  1. The region of the mhlo.reduce here contains an mhlo.add. The way one would lower mhlo.add to, say, the linalg dialect is very different depending on whether the operation appears inside an mhlo op's region or at the top level. This conflates two distinct uses of the mhlo.add operation. It would be much easier to handle if mhlo.add were only used at the top level and a different operation were used within mhlo operation regions.
  2. The region of the mhlo operation in this case is really a sequence of scalar computations. Using zero-rank tensors introduces additional complexity when translating to the Linalg dialect, since the arguments require a type conversion from zero-rank tensor to scalar. Having scalars before the conversion would remove much of this complexity (see the sketch below).
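
For concreteness, here is a hedged sketch of what the scalar-region form could look like. This is not valid MHLO today; the function name is made up, and arith.addf stands in for whatever dedicated in-region scalar op would be chosen:

func @reduce_scalar_region(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> (tensor<?xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  // Block arguments are plain f32 scalars instead of tensor<f32>.
  ^bb0(%arg2: f32, %arg3: f32):
    // A dedicated scalar op replaces the in-region mhlo.add.
    %1 = arith.addf %arg2, %arg3 : f32
    "mhlo.return"(%1) : (f32) -> ()
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>

  return %0 : tensor<?xf32>
}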

MaheshRavishankar avatar Dec 08 '21 19:12 MaheshRavishankar

Not all reductions are scalar though. The zero-rank tensor is just the degenerate case, but take for example (from the test suite):

func @reduce_valid(%arg0: tensor<4x4xf32>, %arg1 : tensor<4xf32>)
    -> (tensor<4xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32> ):
    %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    "mhlo.return"(%1) : (tensor<4xf32>) -> ()

  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>

  return %0: tensor<4xf32>
}

joker-eph avatar Dec 08 '21 20:12 joker-eph

Yes, I did notice that (and actually didn't know that this existed). Specifically, such an operation cannot be lowered to Linalg directly (today). So maybe all that is needed is an MHLO -> MHLO transform before lowering to Linalg that converts the zero-rank tensor case to scalars and converts mhlo.add operations within such regions to, say, arith operations. I think today those would be marked illegal?
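
For reference, the Linalg form such a lowering would target already takes scalar block arguments. A minimal sketch (function and value names are made up; assumes a reduction over dimension 1 and that %init is the pre-filled accumulator):

#map_in  = affine_map<(d0, d1) -> (d0, d1)>
#map_out = affine_map<(d0, d1) -> (d0)>
func @reduce_as_linalg(%arg0: tensor<?x?xf32>, %init: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map_in, #map_out],
                       iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %acc: f32):
    // The payload here is already scalar; scalar MHLO region arguments
    // plus arith ops would line up with this directly.
    %sum = arith.addf %in, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}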

MaheshRavishankar avatar Dec 08 '21 20:12 MaheshRavishankar

Is this actually a core issue about how we want to model mhlo.reduce going forward, or is this more about the mechanics of lowering to linalg? It feels like lowering mhlo.add to arith.addf for the payload can be handled by being somewhat more sophisticated about the use of the dialect conversion infrastructure (setting up legality properly, or doing the conversion in two phases, or something).

silvasean avatar Dec 08 '21 22:12 silvasean

I don't think you need the dialect conversion framework to handle this. I think the biggest issue is what operations are supported in the MHLO reduce region. I could easily see non-elementwise operations being used in the reduction region, preventing lowering to linalg.
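
For example, nothing in the op definition seems to rule out a region like the following (illustrative only, building on the 4x4 example above), where the non-elementwise mhlo.reverse has no scalar payload equivalent:

%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32>):
  %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // mhlo.reverse is not elementwise, so this region cannot be rewritten
  // as a scalar computation or a linalg.generic payload.
  %2 = "mhlo.reverse"(%1) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  "mhlo.return"(%2) : (tensor<4xf32>) -> ()
}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>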

rsuderman avatar Dec 08 '21 22:12 rsuderman

Not a stakeholder in MHLO per se, but for me mhlo.reduce having a payload that is itself a tensor-based operation is a "higher level abstraction" that needs to be lowered into "something else" before it can be lowered into Linalg. Simply speaking, a payload operating on tensors makes the mhlo.reduce an imperfectly nested loop nest, while a payload operating on scalars makes it a perfectly nested loop nest. A perfectly nested loop nest is a special case of an imperfectly nested one, but lowering an imperfectly nested loop nest is a very different starting point from lowering a perfectly nested one.
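
An scf-level sketch of the distinction, using the 4x4 example above (assumed already bufferized, with %in and %out as made-up memref names and %out pre-filled with the init values):

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
scf.for %i = %c0 to %c4 step %c1 {
  scf.for %j = %c0 to %c4 step %c1 {
    %acc = memref.load %out[%j] : memref<4xf32>
    %val = memref.load %in[%i, %j] : memref<4x4xf32>
    // The entire payload is a single scalar op: perfectly nested.
    %sum = arith.addf %acc, %val : f32
    memref.store %sum, %out[%j] : memref<4xf32>
  }
}
// With a tensor<4xf32> payload, the loop body above would itself contain
// loops, making the overall computation an imperfectly nested loop nest.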

MaheshRavishankar avatar Dec 08 '21 22:12 MaheshRavishankar

So if `mhlo.reduce` does support tensor operations in the payload, further mhlo -> mhlo transformations are needed to get it to a state where it can be lowered to Linalg (as an example).

MaheshRavishankar avatar Dec 08 '21 23:12 MaheshRavishankar