torch-mlir
Avoid doing bias-add by setting the bias value as the `outs` operand
Looking at the IR generated from Torch-MLIR within IREE, after some fusion, I see this kind of pattern:
%114 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%cst_18 : tensor<384xf32>) outs(%49 : tensor<1x128x384xf32>) {
^bb0(%arg1: f32, %arg2: f32):
  linalg.yield %arg1 : f32
} -> tensor<1x128x384xf32>
%115 = linalg.batch_matmul ins(%113, %cst_182 : tensor<1x128x384xf32>, tensor<1x384x384xf32>)
    outs(%114 : tensor<1x128x384xf32>) -> tensor<1x128x384xf32>
If I am reading this correctly, this is a batch_matmul followed by a bias-add, written as a broadcast of the bias into the output shape of the batch_matmul, with the broadcast result then used as the outs operand of the batch_matmul. I am not sure this is the best way to represent the computation; it definitely trips up fusion at the Linalg level. A better representation would be:
%fill = linalg.fill ins(%cst_zero : f32) outs(%49 : tensor<1x128x384xf32>) -> tensor<1x128x384xf32>
%115 = linalg.batch_matmul ins(%113, %cst_182 : tensor<1x128x384xf32>, tensor<1x384x384xf32>)
    outs(%fill : tensor<1x128x384xf32>) -> tensor<1x128x384xf32>
%116 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%115, %cst_18 : tensor<1x128x384xf32>, tensor<384xf32>)
    outs(%49 : tensor<1x128x384xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
  %0 = arith.addf %b0, %b1 : f32
  linalg.yield %0 : f32
} -> tensor<1x128x384xf32>
At a very preliminary level, this representation avoids the explicit broadcast of %cst_18 (and FWIW the TensorFlow MLIR lowering of op + bias-add is done this way). I tend to think of this as a more canonical representation of the computation here.
Thanks Mahesh, somehow the former IR naively felt more minimal at the time, but thanks for the feedback! This is why we co-design :)
This looks like a part of the AtenLinearOp lowering to linalg. I can modify the same lowering (the linalg conversion pass) to separate out the bias addition, but moving the lowering to the decomposition pass (AtenMatMul + AtenAdd) seems like a better and cleaner approach (a rough sketch is below). What do you think @silvasean @MaheshRavishankar ?
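A minimal sketch of what that decomposition could look like at the torch dialect level, assuming the usual aten.linear semantics of input x weight^T + bias (the value names and shapes here are illustrative, not taken from the IR above):
// Hypothetical decomposition of torch.aten.linear into matmul + add.
// %input : [1,128,384], %weight : [384,384], %bias : [384] (illustrative shapes).
%wt = torch.aten.t %weight : !torch.vtensor<[384,384],f32> -> !torch.vtensor<[384,384],f32>
%mm = torch.aten.matmul %input, %wt : !torch.vtensor<[1,128,384],f32>, !torch.vtensor<[384,384],f32> -> !torch.vtensor<[1,128,384],f32>
%int1 = torch.constant.int 1
%out = torch.aten.add.Tensor %mm, %bias, %int1 : !torch.vtensor<[1,128,384],f32>, !torch.vtensor<[384],f32>, !torch.int -> !torch.vtensor<[1,128,384],f32>
The existing linalg lowerings of matmul and add would then produce the batch_matmul-into-fill plus elementwise-add structure suggested above, instead of the broadcast-into-outs pattern.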
Seems fine to decompose it.
There is an open patch for the decomposition: https://github.com/llvm/torch-mlir/pull/862
CI fails for this PR because the aten.matmul op does not handle higher-dimensional cases. A specific test case with (3D, 2D) inputs fails here. I am trying to handle the cases for aten.matmul where at least one matrix is 3D.
Possible approaches in my mind:
a) Lower to linalg.batch_matmul: 1. broadcast the lower-rank matrix along the batch dimensions, 2. collapse the batch dimensions, 3. do the matrix multiply with linalg.batch_matmul, 4. expand the batch dimensions back. Although this approach seems to be efficient, the 4th step will create issues for dynamic dimensions AFAIK (a rough IR sketch follows this list).
b) Lower to linalg.generic: this seems to be an inefficient approach.
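As a rough sketch of (a), assuming a [2,3,4,5] LHS and a [5,6] RHS (all value names and shapes here are illustrative, not from the PR; %rhs_init and %init stand for init tensors of the appropriate shape):
// 1. Broadcast the 2D RHS across the LHS batch dimensions: [5,6] -> [2,3,5,6].
%rhs_bcast = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%rhs : tensor<5x6xf32>) outs(%rhs_init : tensor<2x3x5x6xf32>) {
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32
} -> tensor<2x3x5x6xf32>
// 2. Collapse all batch dimensions into a single leading one.
%lhs_c = tensor.collapse_shape %lhs [[0, 1], [2], [3]] : tensor<2x3x4x5xf32> into tensor<6x4x5xf32>
%rhs_c = tensor.collapse_shape %rhs_bcast [[0, 1], [2], [3]] : tensor<2x3x5x6xf32> into tensor<6x5x6xf32>
// 3. A single batched matmul over the collapsed batch dimension.
%zero = arith.constant 0.0 : f32
%fill = linalg.fill ins(%zero : f32) outs(%init : tensor<6x4x6xf32>) -> tensor<6x4x6xf32>
%mm = linalg.batch_matmul ins(%lhs_c, %rhs_c : tensor<6x4x5xf32>, tensor<6x5x6xf32>)
    outs(%fill : tensor<6x4x6xf32>) -> tensor<6x4x6xf32>
// 4. Expand the batch dimensions back: [6,4,6] -> [2,3,4,6].
%res = tensor.expand_shape %mm [[0, 1], [2], [3]] : tensor<6x4x6xf32> into tensor<2x3x4x6xf32>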
What do you think @silvasean @MaheshRavishankar ?
@ThomasRaoux
a) is the approach I imagined.
Another related issue: https://github.com/llvm/torch-mlir/issues/879
I am not sure I follow (a) fully. If the case is 3D LHS, 2D RHS, I would expect it to be lowered as
- Broadcast the 2D RHS to 3D RHS
- Use batch-matmul
I don't understand the "collapse batch dimensions" and the "expand the batch dimensions" parts of (a) above.
This has a downside though: the broadcast from the 2D RHS to the 3D RHS will have to be materialized in memory. That's both a computation and a memory cost. It would be interesting to see if just using a linalg.generic works here. In the final state of things, a linalg.matmul/linalg.batch_matmul and a linalg.generic that express the same computation should end up being handled the same way, but we might not be there yet.
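For reference, a rough sketch of such a linalg.generic for the (3D LHS, 2D RHS) case, where the broadcast is folded into the indexing map instead of being materialized (shapes and value names are illustrative, not from the thread; %fill is a zero-filled init tensor; d0 = batch, d1 = M, d2 = N, d3 = K):
// Batched matmul with a 2D RHS; the RHS indexing map simply drops the batch
// dimension, so no broadcast tensor is ever materialized.
%res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<1x128x384xf32>, tensor<384x384xf32>)
    outs(%fill : tensor<1x128x384xf32>) {
^bb0(%a: f32, %b: f32, %acc: f32):
  %p = arith.mulf %a, %b : f32
  %s = arith.addf %acc, %p : f32
  linalg.yield %s : f32
} -> tensor<1x128x384xf32>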
> I don't understand the "collapse batch dimensions" and the "expand the batch dimensions" parts of (a) above.
torch's matmul allows arbitrary leading batch dimensions and combinations of broadcasting, e.g. [42,2,1,4,5,6] x [1,3,4,6,7]. So in general we need to resolve all of that down to a single leading batch dimension for linalg.batch_matmul.
It would be nice if linalg could be improved so that these broadcasts aren't materialized in memory.
@Shukla-Gaurav @erman-gurses anyone working on this?
The fix for this should roll in the change that @makslevental made here: https://github.com/llvm/torch-mlir/pull/919
@silvasean I am working on this, will take care of #919 also. thanks!