
[LLVMCPU] Bad packing codegen with different `outer_dims_perm`

Open • Max191 opened this issue • 2 comments

Some `tensor.pack` and `tensor.unpack` ops become significantly slower depending on their `outer_dims_perm` value. The following gist contains a slow `tensor.pack` case, a slow `tensor.unpack` case, and a fast `tensor.pack` case for comparison: https://gist.github.com/Max191/a32c07b72272e74cf625cd810ae09c0a

Compile with:

iree-compile packing.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=znver4 \
  --iree-llvmcpu-enable-ukernels=mmt4d \
  -o /tmp/packing.vmfb
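As a reference for what `outer_dims_perm` changes, here is a minimal Python sketch of the packed-shape rule. It assumes the upstream MLIR `tensor.pack` semantics: each tiled dim becomes a tile count (rounded up), the tile-count dims are reordered by `outer_dims_perm`, and the tile sizes are appended in `inner_dims_pos` order. The function `packed_shape` and the 128x256 example are hypothetical illustrations, not IREE code.

```python
def packed_shape(shape, inner_dims_pos, inner_tiles, outer_dims_perm=None):
    """Shape of a tensor.pack result (equivalently, a tensor.unpack source).

    Illustrative reference model of the shape rule only; not IREE code.
    """
    outer = list(shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        # A tiled dim becomes a tile count; ceil division keeps a partial last tile.
        outer[pos] = -(-shape[pos] // tile)
    if outer_dims_perm is not None:
        # Outer position i holds tile-grid dim outer_dims_perm[i].
        outer = [outer[p] for p in outer_dims_perm]
    # The tile sizes follow as the innermost dims, in inner_dims_pos order.
    return outer + list(inner_tiles)


# Same 8x32 tiles, different order of the outer (tile-grid) dims.
print(packed_shape([128, 256], [0, 1], [8, 32]))          # [16, 8, 8, 32]
print(packed_shape([128, 256], [0, 1], [8, 32], [1, 0]))  # [8, 16, 8, 32]
```

Both layouts hold the same elements in the same 8x32 tiles; only the order of the tile grid in memory differs, which is what changes the access pattern the generated code has to handle.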

This gist also shows the difference: https://gist.github.com/Max191/2d6a74f4f7be1951ac359b6fd8db60ca

One benchmark is an unpack followed by a transpose, and the other is a single unpack that computes the same result. Tile size selection differs between the two. These benchmarks can be compiled with:

iree-compile packing.mlir \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu=znver4 \
  --iree-llvmcpu-enable-ukernels=mmt4d \
  --compile-from=executable-sources \
  -o /tmp/packing.vmfb
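The equivalence between the two benchmarks can be checked with a small Python model: an unpack with `outer_dims_perm = [2, 0, 1]` produces the same data as an identity-perm unpack into a 3-D intermediate followed by a transpose. This is a sketch under assumptions: `unpack` and `transpose` are hypothetical reference functions on flat row-major lists, the shapes are a scaled-down analog of the bad unpack case, and the gist's actual IR may express the transpose differently.

```python
from itertools import product
from math import prod


def row_major_offset(idx, shape):
    """Flat offset of idx in a contiguous row-major buffer of the given shape."""
    off = 0
    for i, d in zip(idx, shape):
        off = off * d + i
    return off


def unpack(src, src_shape, dest_shape, inner_dims_pos, inner_tiles, outer_dims_perm):
    """Reference tensor.unpack on flat row-major lists (illustrative, not IREE code)."""
    dest = [None] * prod(dest_shape)
    for idx in product(*(range(d) for d in dest_shape)):
        outer, inner = list(idx), []
        for pos, tile in zip(inner_dims_pos, inner_tiles):
            outer[pos] = idx[pos] // tile  # which tile along this dim
            inner.append(idx[pos] % tile)  # position inside that tile
        src_idx = [outer[p] for p in outer_dims_perm] + inner
        dest[row_major_offset(idx, dest_shape)] = src[row_major_offset(src_idx, src_shape)]
    return dest


def transpose(src, src_shape, perm):
    """Transpose where result dim i is source dim perm[i]."""
    dst_shape = [src_shape[p] for p in perm]
    dst = [None] * len(src)
    for idx in product(*(range(d) for d in dst_shape)):
        src_idx = [0] * len(perm)
        for i, p in enumerate(perm):
            src_idx[p] = idx[i]
        dst[row_major_offset(idx, dst_shape)] = src[row_major_offset(src_idx, src_shape)]
    return dst, dst_shape


# Scaled-down analog: dest 5x4x3, 2x2 tiles on dims 0 and 1 (dim 0 has a partial tile).
src_shape = [3, 3, 2, 2, 2]
src = list(range(72))

# (1) A single unpack with outer_dims_perm = [2, 0, 1].
direct = unpack(src, src_shape, [5, 4, 3], [0, 1], [2, 2], [2, 0, 1])

# (2) Identity-perm unpack into 3x5x4 (tiles on dims 1 and 2), then transpose to 5x4x3.
tmp = unpack(src, src_shape, [3, 5, 4], [1, 2], [2, 2], [0, 1, 2])
via_transpose, shape = transpose(tmp, [3, 5, 4], [1, 2, 0])

print(direct == via_transpose)  # True
```

The two paths agree element for element, so any performance difference between such a pair comes from how each form is lowered (for example, tile size selection), not from what it computes.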

Max191, May 21 '24 19:05

The pack issue is temporarily covered by a ukernel, so let's focus on the unpack kernel in this issue. The unpack op has transpose variants, and codegen now needs to take them into account. @pashu123, please take a stab at it.

  func.func @unpack_bad(%arg0: tensor<64x1828x8x16x16xf32>) -> tensor<29241x128x64xf32> {
    %4 = tensor.empty() : tensor<29241x128x64xf32>
    // Dest dim 2 (size 64, untiled) is the outermost source dim; dims 0 and 1 are tiled 16x16
    // (1828 = ceil(29241 / 16), 8 = 128 / 16).
    %unpack = tensor.unpack %arg0 outer_dims_perm = [2, 0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %4 : tensor<64x1828x8x16x16xf32> -> tensor<29241x128x64xf32>
    return %unpack : tensor<29241x128x64xf32>
  }
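To see why `outer_dims_perm = [2, 0, 1]` can be awkward in this case, the sketch below computes, in plain Python, how far apart in the packed source two neighbouring elements of the unpacked result are. It assumes a contiguous row-major source and models addressing only, not IREE's tiling; `unpack_source_offset` and `innermost_step` are hypothetical helper names.

```python
def row_major_offset(idx, shape):
    """Flat offset of idx in a contiguous row-major buffer of the given shape."""
    off = 0
    for i, d in zip(idx, shape):
        off = off * d + i
    return off


def unpack_source_offset(idx, src_shape, inner_dims_pos, inner_tiles, outer_dims_perm):
    """Flat offset in the packed source that tensor.unpack reads for dest index idx."""
    outer, inner = list(idx), []
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = idx[pos] // tile  # tile-grid coordinate
        inner.append(idx[pos] % tile)  # coordinate inside the tile
    return row_major_offset([outer[p] for p in outer_dims_perm] + inner, src_shape)


def innermost_step(src_shape, outer_dims_perm):
    """Source distance between dest[0, 0, 0] and dest[0, 0, 1], using 16x16 tiles on dims 0 and 1."""
    tiling = ([0, 1], [16, 16], outer_dims_perm)
    return (unpack_source_offset((0, 0, 1), src_shape, *tiling)
            - unpack_source_offset((0, 0, 0), src_shape, *tiling))


# Source shape of @unpack_bad, taken from its MLIR signature.
print(innermost_step([64, 1828, 8, 16, 16], [2, 0, 1]))  # 3743744
# Same dest with an identity outer_dims_perm: the source would be 1828x8x64x16x16.
print(innermost_step([1828, 8, 64, 16, 16], [0, 1, 2]))  # 256
```

With this permutation, stepping one element along the innermost result dimension (size 64) jumps 3,743,744 elements in the source, versus 256 with an identity permutation; only steps inside a 16x16 tile stay local. Note also that 29241 = 1827 * 16 + 9, so the last tile along dim 0 carries 9 valid rows and 7 padding rows that the unpack drops.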

hanhanW, May 22 '24 00:05

Putting a note here. I think the current plan is:

  1. Enable the unpack ukernels.
  2. Measure and understand the performance gap.
  3. Plan out the work for unpack codegen.

hanhanW, May 23 '24 21:05