
Add pass to generalize pack ops if they are consumed by flow.dispatch.tensor.store ops

Open nirvedhmeshram opened this issue 5 months ago • 7 comments

Pack ops can affect tiling decisions, so it is beneficial to generalize them. For example, consider the IR below:

    %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%4 : tensor<8x768xf32>) {
    ^bb0(%out: f32):
      %6 = linalg.index 0 : index
      %7 = linalg.index 1 : index
      %extracted = tensor.extract %2[%6, %c0, %7] : tensor<8x128x768xf32>
      linalg.yield %extracted : f32
    } -> tensor<8x768xf32>
    %pack = tensor.pack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %3 : tensor<8x768xf32> -> tensor<1x768x8x1xf32>

The gather linalg.generic will get the following tiling config: {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[8, 512], [8, 1], [0, 0], [0, 0]]>}. This config is not supported by current upstream vectorization (https://github.com/llvm/llvm-project/issues/107476). It should be, and that is being worked on, but we did not need to reach this tiling config in the first place.

With this PR the above IR will simplify to

  %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%3 : tensor<8x768xf32>) {
  ^bb0(%out: f32):
    %6 = linalg.index 0 : index
    %7 = linalg.index 1 : index
    %extracted = tensor.extract %2[%6, %c128, %7] : tensor<8x128x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x768xf32>
  %5 = tensor.empty() : tensor<768x8xf32>
  %transposed = linalg.transpose ins(%4 : tensor<8x768xf32>) outs(%5 : tensor<768x8xf32>) permutation = [1, 0] 

With this IR we get the following vectorizable tiling config: {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 8], [1, 4], [0, 0], [0, 0]]>}
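
To see why the rewrite is value-preserving for this particular pack, here is a minimal pure-Python sketch. It assumes the usual tensor.pack semantics for a 2-D source with outer_dims_perm = [0, 1], inner_dims_pos = [0, 1] and no padding. The helper names (pack_2d, drop_unit_dims, transpose_2d) are hypothetical and exist only for illustration; they are not IREE or MLIR APIs.

```python
def pack_2d(src, inner_tiles):
    """Model of tensor.pack on a 2-D source (identity outer perm, no padding).

    Returns a nested list of shape [rows/t0][cols/t1][t0][t1].
    """
    t0, t1 = inner_tiles
    rows, cols = len(src), len(src[0])
    assert rows % t0 == 0 and cols % t1 == 0, "sketch assumes no padding"
    return [
        [
            [[src[o0 * t0 + i0][o1 * t1 + i1] for i1 in range(t1)]
             for i0 in range(t0)]
            for o1 in range(cols // t1)
        ]
        for o0 in range(rows // t0)
    ]


def drop_unit_dims(packed):
    """Collapse a 1 x N x T x 1 result to N x T."""
    assert len(packed) == 1
    return [[inner[0] for inner in tile] for tile in packed[0]]


def transpose_2d(src):
    """Model of linalg.transpose with permutation = [1, 0]."""
    return [list(col) for col in zip(*src)]


# Same shapes as in the IR above: tensor<8x768xf32> with inner_tiles = [8, 1].
src = [[r * 768 + c for c in range(768)] for r in range(8)]
packed = pack_2d(src, (8, 1))          # shape 1 x 768 x 8 x 1
assert drop_unit_dims(packed) == transpose_2d(src)  # shape 768 x 8
```

Because the outer dimension of size 1 and the inner tile of size 1 carry no data movement, the pack here reduces to a plain 2-D transpose, which matches the tensor<768x8xf32> result of the linalg.transpose in the simplified IR.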

We only want to do this when the pack is consumed solely by flow.dispatch.tensor.store. We don't apply any transforms in other cases, as that would have implications for code generation that we don't want.

Fixes: https://github.com/iree-org/iree/issues/18413

nirvedhmeshram avatar Sep 17 '24 16:09 nirvedhmeshram