[DataTiling] Support more consumer fusions cases in data tiling fusion path
This issue tracks work on supporting more complicated consumer fusion cases in the data tiling fusion path. More concretely, it tracks support for fusing consumers with multiple input operands.
With https://github.com/iree-org/iree/pull/20901 and https://github.com/iree-org/iree/pull/20898, we have enough propagation to handle simple unary element-wise consumers by propagating data tiling ops after materialization. This works because there are no additional operands on the consumer that need to be packed/swizzled into the new layout, so the consumer can simply be rewritten in the new layout. The more complicated case is when the consumer has additional operands that also need to be in the data tiled layout. For example:
...
%data_tiled_multi_mma = iree_gpu.multi_mma ... tensor<?x?x4x8x2x4x16x4xf32>
%transposed = linalg.transpose ins(%data_tiled_multi_mma ... permutation = [0, 1, 5, 3, 7, 2, 6, 4] ... tensor<?x?x4x8x4x4x16x2xf32>
%collapsed = tensor.collapse_shape %transposed [[0], [1], [2, 3, 4], [5, 6, 7]] ... into tensor<?x?x128x128xf32>
%unpack = linalg.unpack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [128, 128] ... -> tensor<?x?xf32>
%consumer = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%unpack, %bias : tensor<?x?xf32>, tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%add = arith.addf %in, %in_0 : f32
linalg.yield %add : f32
} -> tensor<?x?xf32>
...
// There can be more layout transformations following the %consumer
The %consumer needs to directly consume the %data_tiled_multi_mma if we want any hope of keeping the values from the multi_mma in registers through the consumer. This means we need to propagate all of the layout transformations (at least the tensor.collapse_shape and linalg.unpack) past the consumer op. To do this, the consumer needs to be rewritten in the data tiled layout, and the %bias operand needs to be transformed into a compatible layout. That means adding extra layout transformations (linalg.pack, tensor.expand_shape, and a matching linalg.transpose) to bring %bias into the same layout, and we need to be able to generate code for this sequence.
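For concreteness, the added transformation on %bias could look roughly like the following sketch, obtained by inverting the unpack/collapse/transpose chain in the example above (SSA names are illustrative and destination operands are elided the same way as in the example; note that [0, 1, 5, 3, 7, 2, 6, 4] is its own inverse, so the same permutation shows up again):
%packed = linalg.pack %bias outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [128, 128] ... -> tensor<?x?x128x128xf32>
%expanded = tensor.expand_shape %packed [[0], [1], [2, 3, 4], [5, 6, 7]] ... into tensor<?x?x4x8x4x4x16x2xf32>
%bias_tiled = linalg.transpose ins(%expanded ... permutation = [0, 1, 5, 3, 7, 2, 6, 4] ... tensor<?x?x4x8x2x4x16x4xf32>
// %bias_tiled now matches the layout of %data_tiled_multi_mma, so the
// rewritten %consumer can take it directly as its second input.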
I see two reasonable options for this, both of which require significant work. We may also ultimately want to use both solutions in practice.
MapGatherOp
The first option is to introduce an iree_linalg_ext.map_gather op mirroring the existing iree_linalg_ext.map_scatter, so that the layout transformation on the %bias operand can be folded into a map_gather (see the sketch after the diagram below). This would make generating code for the fusion simple after we propagate. It requires introducing the map_gather op and plumbing it through codegen similarly to map_scatter. There may also be unforeseen difficulties, because the shape of the program graph is unlike what we typically see, i.e., we have:
root_op    map_gather
       \   /
      consumer
         |
    map_scatter
This probably won't be much of an issue, but it is not well tested, so we will find out.
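As a rough illustration of the op itself (map_gather does not exist yet, so the syntax below is purely a sketch that assumes a region-based index mapping analogous to map_scatter):
%bias_tiled = iree_linalg_ext.map_gather %bias into %init {
^bb0(%d0: index, %d1: index, ...):
  // Map each index of the data tiled result space back to the (row, col)
  // position in %bias, i.e. the inverse of the transpose + collapse_shape +
  // unpack chain above.
  ...
  iree_linalg_ext.yield %src_row, %src_col : index, index
} : tensor<?x?xf32> into tensor<?x?x4x8x2x4x16x4xf32>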
Encoding Propagation + Hoisting
The alternative option is to try to hoist the layout transformation out of the dispatch and move it into a producer dispatch. This means the propagation needs to happen when we are forming dispatches, while the transformation is still represented by encodings. This makes codegen simple, because everything is already in the data tiled layout, but there is plenty of work to make the propagation + hoisting + fusion work at the dispatch creation level.
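A hedged sketch of what this could mean at dispatch creation time, with the relayout of %bias expressed as an encoding (the #encoding attribute and the dispatch structure are illustrative and heavily elided):
// %bias gets the encoding and is materialized in a producer dispatch (or
// hoisted further if it is constant-like).
%bias_enc = iree_encoding.set_encoding %bias : tensor<?x?xf32> -> tensor<?x?xf32, #encoding>
// The fused dispatch only sees operands that already carry the encoding, so
// after materialization there is no extra relayout inside the dispatch.
%fused = flow.dispatch.region ... {
  %mm = linalg.matmul ... -> tensor<?x?xf32, #encoding>
  %add = linalg.generic ... ins(%mm, %bias_enc ... -> tensor<?x?xf32, #encoding>
  ...
}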
Ultimately, we probably want a combination of both solutions, because it may not always be beneficial to hoist the layout transformation on %bias out of the dispatch region, in which case we will still need to generate code for it inside the dispatch.
cc @hanhanW