llvm-project icon indicating copy to clipboard operation
llvm-project copied to clipboard

[mlir][scf] Extend consumer fuse to nested loop structure

Open Yun-Fly opened this issue 1 year ago • 9 comments

Hi, following the earlier discussion in this thread, this patch extends the new consumer-fusion feature to more complex nested loop structures, e.g.:

// Input IR: a linalg.matmul tiled at two levels -- 128x128 tiles by the outer
// scf.forall (2x2 iterations) and 64x64 tiles by two nested scf.for loops --
// followed by a linalg.add that consumes the loop nest's result.
// Maps a forall induction variable to the offset of its 128-wide tile.
#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    // Outer tiling level: each of the 2x2 forall iterations owns one 128x128
    // tile of the 256x256 result.
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
      %iv0 = affine.apply #map(%arg3)
      %iv1 = affine.apply #map(%arg4)
      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
      // Inner tiling levels: two scf.for loops step by 64 over the 128x128
      // tile; each innermost iteration computes one 64x64 matmul tile.
      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
          // First insert level: the 64x64 tile goes into the 128x128 iter_arg.
          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %insert_slice : tensor<128x128xf32>
        }
        scf.yield %3 : tensor<128x128xf32>
      }
      scf.forall.in_parallel {
         // Second insert level: the 128x128 tile goes into the 256x256 shared output.
         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    // Consumer of the loop nest's result; this is the op that gets fused into
    // the innermost scf.for (see the resulting IR).
    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    return %5 : tensor<256x256xf32>
  }
}

What's new in this PR:

  1. Support nested loop structures, including both scf.for and scf.forall.
  2. Support multi-level tensor.insert_slice / tensor.parallel_insert_slice.

NOTE: this PR does NOT address the getTiledImplementation refactor we discussed earlier; it focuses only on the functional enhancement. The example above also shows the same problem as before, this time on the dpsInits: the semantics of the tiled operand do not match the assumption made by the current getTiledImplementation. To unblock this necessary patch, I temporarily follow the approach @MaheshRavishankar suggested and use a dummy insert_slice to bridge that gap.

The resulting IR looks like this:

// Resulting IR: the linalg.add consumer has been fused into the innermost
// scf.for. The forall and both scf.for loops each gained an extra
// loop-carried value that holds the consumer's result.
// Maps a forall induction variable to the offset of its 128-wide tile.
#map = affine_map<(d0) -> (d0 * 128)>
// Adds an inner-loop offset (d0) to the outer forall tile offset (d1 * 128),
// giving an absolute offset into the 256x256 tensors.
#map1 = affine_map<(d0, d1) -> (d0 + d1 * 128)>
module {
  module {
    func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
      %c0 = arith.constant 0 : index
      %c64 = arith.constant 64 : index
      %c128 = arith.constant 128 : index
      %cst = arith.constant 0.000000e+00 : f32
      %0 = tensor.empty() : tensor<256x256xf32>
      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
      // New shared_out %arg6, initialized from %0 (the add's init), carries
      // the fused consumer's result.
      %2:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1, %arg6 = %0) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
        %3 = affine.apply #map(%arg3)
        %4 = affine.apply #map(%arg4)
        %extracted_slice = tensor.extract_slice %arg5[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %extracted_slice_0 = tensor.extract_slice %arg0[%3, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
        %extracted_slice_1 = tensor.extract_slice %arg1[0, %4] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
        // 128x128 tile of the consumer's init, threaded into both scf.for
        // loops as an additional iter_arg.
        %extracted_slice_2 = tensor.extract_slice %arg6[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %5:2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %extracted_slice, %arg9 = %extracted_slice_2) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
          %6:2 = scf.for %arg10 = %c0 to %c128 step %c64 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
            %extracted_slice_3 = tensor.extract_slice %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %extracted_slice_4 = tensor.extract_slice %extracted_slice_0[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
            %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[0, %arg10] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
            %7 = linalg.matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
            // Absolute offsets of the current 64x64 tile within %arg2.
            %8 = affine.apply #map1(%arg7, %arg3)
            %9 = affine.apply #map1(%arg10, %arg4)
            %extracted_slice_6 = tensor.extract_slice %arg2[%8, %9] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
            %extracted_slice_7 = tensor.extract_slice %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            // Fused, tiled consumer: adds the 64x64 matmul tile %7 to the
            // matching tile of %arg2.
            %10 = linalg.add ins(%7, %extracted_slice_6 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%extracted_slice_7 : tensor<64x64xf32>) -> tensor<64x64xf32>
            %inserted_slice = tensor.insert_slice %7 into %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            // The consumer's tile is inserted into its own iter_arg at the same offsets.
            %inserted_slice_8 = tensor.insert_slice %10 into %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            scf.yield %inserted_slice, %inserted_slice_8 : tensor<128x128xf32>, tensor<128x128xf32>
          }
          scf.yield %6#0, %6#1 : tensor<128x128xf32>, tensor<128x128xf32>
        }
        scf.forall.in_parallel {
          // Both 128x128 results are written back at the same [%3, %4] tile offset.
          tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
          tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
        }
      }
      // Returns the forall's second result (the fused add's output); the
      // matmul result %2#0 is not returned.
      return %2#1 : tensor<256x256xf32>
    }
  }
}

Looking forward to your suggestions and review. Thanks!

Yun-Fly avatar Jun 03 '24 07:06 Yun-Fly