iree
iree copied to clipboard
Improve FuseMultiUseElementwiseProducersPass
The helper function `isHorizontalToGroup` relies on `getBackwardSlice`, which does not include the defining operations of values that are defined above and used inside the body of an operation (as opposed to being passed as an operand): https://github.com/iree-org/iree/blob/114a1427810f3da0234f98c22f58390773b0489a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp#L104
To be conservative, `isHorizontalToGroup` returns false when there is a use of a value defined above. This could be improved by tracking values defined above and adding their defining operations to the slice. That would also require a corresponding change to `moveOperandDefs`, so that the definitions of values defined above are moved as well: https://github.com/iree-org/iree/blob/114a1427810f3da0234f98c22f58390773b0489a/compiler/src/iree/compiler/DispatchCreation/FusionUtils.h#L37
This would allow for cases like the following to be fused:
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
%cst = arith.constant 1.000000e+00 : f32
%cst_0 = arith.constant 2.000000e+00 : f32
%0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%in: f32, %out: f32):
%2 = arith.addf %in, %cst : f32
linalg.yield %2 : f32
} -> tensor<5x5xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%in: f32, %out: f32):
// Note %cst_0 defined above
%2 = arith.subf %in, %cst_0 : f32
linalg.yield %2 : f32
} -> tensor<5x5xf32>
util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
}
}