MHLO operation regions need to use scalar arguments
MHLO operations that have regions use zero-rank tensors to represent what are really scalar values. For example:
```mlir
func @reduce_one_op_all_locs_same(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<f32> loc("foo"), %arg3: tensor<f32> loc("foo")):
    %1 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32> loc("foo")
    "mhlo.return"(%1) : (tensor<f32>) -> () loc("foo")
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")
  return %0: tensor<?xf32>
}
```
There are a couple of issues here.
- The region of the `mhlo.reduce` here has an `mhlo.add`. The way one would lower `mhlo.add` to, say, the `linalg` dialect is very different depending on whether this operation is within an `mhlo` op or at the top level. This seems to be a conflation between different uses of an `mhlo.add` operation. It would be much easier to handle this if `mhlo.add` was only used at the top level and a different operation was used within `mhlo` operations.
- The region of the `mhlo` operation in this case seems to be a sequence of computations on what are really scalars. Using zero-rank tensors introduces additional complexity when translating this to the `Linalg` dialect, since it requires a type conversion of the arguments from zero-rank tensor to scalar. Having scalars before the conversion would remove a lot of this complexity (a hypothetical scalarized form is sketched below).
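For illustration, here is a sketch of what such a scalarized form could look like. This is not valid MHLO today: the scalar block arguments, the use of `arith.addf` in the region, and the function name are assumptions about the proposed design, not existing syntax.

```mlir
// Hypothetical scalarized form: the region takes scalars directly and
// uses an arith op instead of mhlo.add. Not accepted by MHLO as-is.
func @reduce_scalar_payload(%arg0: tensor<?x?xf32>, %arg1 : tensor<f32>) -> (tensor<?xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: f32, %arg3: f32):
    %1 = arith.addf %arg2, %arg3 : f32
    "mhlo.return"(%1) : (f32) -> ()
  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```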
Not all reductions operate on scalars though. The zero-rank tensor is just the degenerate case; take for example (from the test suite):
```mlir
func @reduce_valid(%arg0: tensor<4x4xf32>, %arg1 : tensor<4xf32>) -> (tensor<4xf32>) {
  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32>):
    %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    "mhlo.return"(%1) : (tensor<4xf32>) -> ()
  }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0: tensor<4xf32>
}
```
Yes, I did notice that (and actually didn't know that this existed). Specifically, such an operation cannot be lowered to Linalg directly (today). So maybe all that is needed is an MHLO -> MHLO transform before lowering to Linalg that converts the zero-rank tensor case to scalars (as sketched above) and converts `mhlo.add` operations within such regions to, say, `arith` operations. I think today those would be marked illegal?
Is this actually a core issue about how we want to model `mhlo.reduce` going forward, or is this more about the mechanics of lowering to `linalg`? It feels like lowering `mhlo.add` to `arith.addf` for the payload can be handled by being somewhat more sophisticated about the use of the dialect conversion infrastructure (setting up legality properly, or doing the conversion in two phases, or something).
I don't think you need the dialect conversion framework to handle this. I think the biggest issue is what operations are supported in the MHLO reduce region. I could easily see non-elementwise operations being used in the reduction region, preventing lowering to `linalg`.
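As a hypothetical illustration of that concern (the shapes and the use of `mhlo.dot` are made up for this example, and it is unclear whether the verifier would even accept it): a region whose payload is not elementwise has no scalar analogue, so there would be nothing to put in a `linalg.generic` body.

```mlir
// Hypothetical: %arg0 : tensor<4x4x4xf32>, %arg1 : tensor<4x4xf32>.
// The region payload is a matmul, not an elementwise op, so it cannot
// become the scalar body of a linalg.generic.
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>):
  %1 = "mhlo.dot"(%arg2, %arg3) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
  "mhlo.return"(%1) : (tensor<4x4xf32>) -> ()
}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
```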
> Is this actually a core issue about how we want to model `mhlo.reduce` going forward, or is this more about the mechanics of lowering to `linalg`? It feels like lowering `mhlo.add` to `arith.addf` for the payload can be handled by being somewhat more sophisticated about the use of the dialect conversion infrastructure (setting up legality properly, or doing the conversion in two phases, or something).
Not a stakeholder in MHLO per se, but for me `mhlo.reduce` having a payload that is itself a tensor-based operation is a "higher level abstraction" that needs to be lowered into "something else" before it can be lowered into Linalg. Simply speaking, the payload operating on tensors makes the `mhlo.reduce` an imperfectly nested loop nest, while operating on scalars makes it a perfectly nested loop nest. A perfectly nested loop nest is a special case of an imperfectly nested loop nest, but lowering an imperfectly nested loop nest is a different starting point compared to lowering a perfectly nested loop nest.
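For comparison, a sketch of the perfectly nested case: when the payload is already scalar, it drops straight into the body of a `linalg.generic`. The `%init` accumulator tensor, the map names, and the string spelling of `iterator_types` are illustrative assumptions here, not output of an actual lowering.

```mlir
#map_in  = affine_map<(d0, d1) -> (d0, d1)>
#map_out = affine_map<(d0, d1) -> (d0)>
// Reduction over dimension 1; %init is assumed to be pre-filled with the
// init value from %arg1. The region body is pure scalar code, so the
// whole op is a perfectly nested loop nest.
%res = linalg.generic
    {indexing_maps = [#map_in, #map_out],
     iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %acc: f32):
    %sum = arith.addf %in, %acc : f32
    linalg.yield %sum : f32
} -> tensor<?xf32>
```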