
Tracking issue for pad-based encoding

Open MaheshRavishankar opened this issue 7 months ago • 24 comments

https://github.com/MaheshRavishankar/iree/tree/users/MaheshRavishankar/padEncodinge2e is the branch that I have been working on for enabling pad-based encoding. I have been using the following example as a prototype:

util.func @test(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x2048xf32>,
    %arg2 : tensor<2048x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %M = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %K1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %N = tensor.dim %arg2, %c1 : tensor<2048x?xf32>
  %0 = tensor.empty(%M) : tensor<?x2048xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x2048xf32>)
      outs(%1 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %3 = tensor.empty(%M, %N) : tensor<?x?xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %5 = linalg.matmul ins(%2, %arg2 : tensor<?x2048xf32>, tensor<2048x?xf32>)
      outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
  util.return %5 : tensor<?x?xf32>
}

With that branch I now get a failure in Materialize Encoding. @jtuyls is working on addressing that, but I am filing an issue to record the state of this and iterate on it (I'll be away for a few days).

This is my compilation command:

iree-compile test.mlir \
    --iree-hal-target-device=hip --iree-hip-target=gfx942 \
    -o test.vmfb \
    --iree-opt-level=O3 \
    --iree-opt-data-tiling=false \
    --iree-dispatch-creation-experimental-data-tiling \
    --iree-dispatch-creation-set-encoding-strategy=padding \
    --iree-hip-encoding-layout-resolver=pad 

MaheshRavishankar avatar May 16 '25 05:05 MaheshRavishankar

Another issue I am having is that this test currently segfaults with the following error

LLVM ERROR: can't create Attribute 'mlir::iree_compiler::IREE::Codegen::EncodingNopLayoutAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.

If I could fix this error, I could just land that branch in main.

MaheshRavishankar avatar May 16 '25 06:05 MaheshRavishankar

> Another issue I am having is that this test currently segfaults with the following error
>
> LLVM ERROR: can't create Attribute 'mlir::iree_compiler::IREE::Codegen::EncodingNopLayoutAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.
>
> If I could fix this error, I could just land that branch in main.

I created a fix for this here: https://github.com/iree-org/iree/pull/20837. However, for this test you likely want to add the GPU target to the #hal.device.target, so that the GPU target and cache info are found:

iree.gpu.target = #iree_gpu.target<arch = "gfx942",
                                   features = "",
                                   wgp = <compute = fp32,
                                          storage = b32,
                                          subgroup = none,
                                          dot = none,
                                          mma = [<MFMA_F32_16x16x4_F32>],
                                          subgroup_size_choices = [64],
                                          max_workgroup_sizes = [1024, 1024, 1024],
                                          max_thread_count_per_workgroup = 1024,
                                          max_workgroup_memory_bytes = 65536,
                                          max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
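
For reference, a minimal sketch of where this attribute can live, attached to the executable target inside the module's device targets (the attribute spelling is taken from the snippet above and may vary across IREE versions):

#exe_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
  iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
                                     wgp = <compute = fp32, storage = b32,
                                            subgroup = none, dot = none,
                                            mma = [<MFMA_F32_16x16x4_F32>],
                                            subgroup_size_choices = [64],
                                            max_workgroup_sizes = [1024, 1024, 1024],
                                            max_thread_count_per_workgroup = 1024,
                                            max_workgroup_memory_bytes = 65536,
                                            max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
}>
#device_target = #hal.device.target<"hip", [#exe_target]> : !hal.device
module attributes {hal.device.targets = [#device_target]} {
  // the util.func @test from the top of the issue goes here
}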

jtuyls avatar May 16 '25 14:05 jtuyls

An update on this:

To get the above matmul prototype working, we need(ed) the following changes/fixes:

  • https://github.com/iree-org/iree/pull/20969
  • https://github.com/iree-org/iree/pull/20971
  • https://github.com/iree-org/iree/pull/20845 -> This one is now ready for review from my side.

While I do see an improvement for the matmul prototype, I am not seeing the same when benchmarking matmul_transpose_b (which is what I see in the Llama3.1 IR):

| IR                 | Baseline runtime (ms) | Pad runtime (ms) |
| ------------------ | --------------------- | ---------------- |
| Matmul             | 107                   | 93.5             |
| Matmul_transpose_b | 6.21                  | 6.20             |

More changes are needed to run Llama3.1 e2e with pad encoding. The following branch compiles and returns numerically correct results, but it skips dispatches whose producer dispatch is attention, as I still see errors there (WIP): https://github.com/jtuyls/iree/tree/padEncodinge2e-2 . However, I am not seeing an improvement over the baseline:

| Sequence length | Baseline runtime (ms) | Pad runtime (ms) |
| --------------- | --------------------- | ---------------- |
| 2048            | 396                   | 398              |

I think this might be because all matmuls are 'matmul_transpose_b', and as noted above, I don't see a performance difference for that case. I am not sure why padding doesn't seem to help there, though; I could use some thoughts on that.

cc @MaheshRavishankar @hanhanW

Notes and Repro IR

Notes:

  • Matmul inputs used are: 1024x128xf32, 2048x128xf32 and 4096x2048xf32. The exact command is:
iree-benchmark-module \
    --device="hip://0" \
    --device_allocator=caching \
    --hip_use_streams=true \
    --module="test_pad.vmfb" \
    --function=test \
    --input=1024x128xf32=1. \
    --input=128x2048xf32=1. \
    --input=2048x4096xf32=1. \
    --benchmark_repetitions=3
  • Baseline is generated with:
iree-compile matmul.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3 \
    -o matmul.vmfb
  • Pad is generated with:
iree-compile matmul.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3 \
    --iree-opt-data-tiling=false \
    --iree-dispatch-creation-experimental-data-tiling \
    --iree-dispatch-creation-set-encoding-strategy=padding \
    --iree-hip-encoding-layout-resolver=pad \
    -o matmul_pad.vmfb
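  • The per-pass dumps quoted later in this thread can be reproduced by appending MLIR's IR-printing flag to the pad command, e.g. (a sketch, assuming iree-compile exposes the upstream pass-manager options; the pass name is as it appears in the dump headers below):
iree-compile matmul.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3 \
    --iree-opt-data-tiling=false \
    --iree-dispatch-creation-experimental-data-tiling \
    --iree-dispatch-creation-set-encoding-strategy=padding \
    --iree-hip-encoding-layout-resolver=pad \
    --mlir-print-ir-after=iree-codegen-materialize-encoding-into-padding \
    -o matmul_pad.vmfb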

Matmul IR:

util.func @test(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x2048xf32>,
    %arg2 : tensor<2048x4096xf32>) -> tensor<?x4096xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %M = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %K1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %0 = tensor.empty(%M) : tensor<?x2048xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x2048xf32>)
      outs(%1 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %3 = tensor.empty(%M) : tensor<?x4096xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
  %5 = linalg.matmul ins(%2, %arg2 : tensor<?x2048xf32>, tensor<2048x4096xf32>)
      outs(%4 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
  util.return %5 : tensor<?x4096xf32>
}

Matmul_transpose_b IR:

util.func @test(%arg0 : tensor<?x?xf32>, %arg1 : tensor<2048x?xf32>, %arg2 : tensor<4096x2048xf32>) -> tensor<?x4096xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %M = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %K1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %0 = tensor.empty(%M) : tensor<?x2048xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<2048x?xf32>) outs(%1 : tensor<?x2048xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %42 = arith.mulf %in, %in_0 : f32
    %43 = arith.addf %out, %42 : f32
    linalg.yield %43 : f32
  } -> tensor<?x2048xf32>
  %extract = tensor.extract_slice %arg2[0, 0] [4096, 2048] [1, 1] : tensor<4096x2048xf32> to tensor<4096x2048xf32>
  %3 = tensor.empty(%M) : tensor<?x4096xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2, %extract : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%4 : tensor<?x4096xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %42 = arith.mulf %in, %in_0 : f32
    %43 = arith.addf %out, %42 : f32
    linalg.yield %43 : f32
  } -> tensor<?x4096xf32>
  util.return %5 : tensor<?x4096xf32>
}

jtuyls avatar Jun 02 '25 19:06 jtuyls

cc @kuhar

MaheshRavishankar avatar Jun 02 '25 21:06 MaheshRavishankar

@jtuyls what are the padding amounts for these shapes?

1024x128xf32, 2048x128xf32 and 4096x2048xf32

kuhar avatar Jun 03 '25 16:06 kuhar

> @jtuyls what are the padding amounts for these shapes?
>
> 1024x128xf32, 2048x128xf32 and 4096x2048xf32

The padding amount is [0, 32], not on those parameters but on the activation between the two matmuls. See this dump after MaterializeEncodingIntoPaddingPass:

// -----// IR Dump After MaterializeEncodingIntoPaddingPass (iree-codegen-materialize-encoding-into-padding) //----- //
func.func @test_dispatch_0_matmul_like_Dx2048xD_f32() {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c32_i64 = arith.constant 32 : i64
  %0 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
  %3 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(3) : i32
  %4 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(4) : i32
  %5 = hal.interface.constant.load layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(5) : i32
  %6 = arith.extui %1 : i32 to i64
  %7 = arith.shli %6, %c32_i64 : i64
  %8 = arith.extui %0 : i32 to i64
  %9 = arith.ori %8, %7 : i64
  %10 = arith.index_castui %9 : i64 to index
  %11 = arith.extui %3 : i32 to i64
  %12 = arith.shli %11, %c32_i64 : i64
  %13 = arith.extui %2 : i32 to i64
  %14 = arith.ori %13, %12 : i64
  %15 = arith.index_castui %14 : i64 to index
  %16 = arith.extui %5 : i32 to i64
  %17 = arith.shli %16, %c32_i64 : i64
  %18 = arith.extui %4 : i32 to i64
  %19 = arith.ori %18, %17 : i64
  %20 = arith.index_castui %19 : i64 to index
  %21:3 = util.assume.int 
      %10<umin = 0, umax = 9007199254740991>, 
      %15<umin = 0, umax = 9007199254740991>, 
      %20<umin = 0, umax = 9007199254740991>
    : index, index, index
  %22 = iree_tensor_ext.dispatch.workload.ordinal %21#0, 0 : index
  %23 = iree_tensor_ext.dispatch.workload.ordinal %21#1, 1 : index
  %24 = iree_tensor_ext.dispatch.workload.ordinal %21#2, 2 : index
  %25 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%24, %22}
  %26 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2048x?xf32>>{%23}
  %27 = hal.interface.binding.subspan layout(<constants = 6, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x2080xf32>>{%24}
  %28 = iree_tensor_ext.dispatch.tensor.load %25, offsets = [0, 0], sizes = [%24, %22], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf32>>{%24, %22} -> tensor<?x?xf32>
  %29 = iree_tensor_ext.dispatch.tensor.load %26, offsets = [0, 0], sizes = [2048, %23], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2048x?xf32>>{%23} -> tensor<2048x?xf32>
  %30 = tensor.empty(%24) : tensor<?x2048xf32>
  %31 = linalg.fill ins(%cst : f32) outs(%30 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %32 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%28, %29 : tensor<?x?xf32>, tensor<2048x?xf32>) outs(%31 : tensor<?x2048xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %33 = arith.mulf %in, %in_0 : f32
    %34 = arith.addf %out, %33 : f32
    linalg.yield %34 : f32
  } -> tensor<?x2048xf32>
  iree_tensor_ext.dispatch.tensor.store %32, %27, offsets = [0, 0], sizes = [%24, 2048], strides = [1, 1] : tensor<?x2048xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x2080xf32>>{%24}
  return
}

// -----// IR Dump After MaterializeEncodingIntoPaddingPass (iree-codegen-materialize-encoding-into-padding) //----- //
func.func @test_dispatch_1_matmul_like_Dx4096x2048_f32() {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c32_i64 = arith.constant 32 : i64
  %0 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = arith.extui %1 : i32 to i64
  %3 = arith.shli %2, %c32_i64 : i64
  %4 = arith.extui %0 : i32 to i64
  %5 = arith.ori %4, %3 : i64
  %6 = arith.index_castui %5 : i64 to index
  %7 = util.assume.int %6<umin = 0, umax = 9007199254740991> : index
  %8 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2048xf32>>
  %9 = iree_tensor_ext.dispatch.workload.ordinal %7, 0 : index
  %10 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9}
  %11 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%9}
  %12 = iree_tensor_ext.dispatch.tensor.load %10, offsets = [0, 0], sizes = [%9, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9} -> tensor<?x2048xf32>
  %13 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [4096, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2048xf32>> -> tensor<4096x2048xf32>
  %14 = tensor.empty(%9) : tensor<?x4096xf32>
  %15 = linalg.fill ins(%cst : f32) outs(%14 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12, %13 : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%15 : tensor<?x4096xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %17 = arith.mulf %in, %in_0 : f32
    %18 = arith.addf %out, %17 : f32
    linalg.yield %18 : f32
  } -> tensor<?x4096xf32>
  iree_tensor_ext.dispatch.tensor.store %16, %11, offsets = [0, 0], sizes = [%9, 4096], strides = [1, 1] : tensor<?x4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%9}
  return
}

jtuyls avatar Jun 03 '25 16:06 jtuyls

So we only pad the LHS?

  %12 = iree_tensor_ext.dispatch.tensor.load %10, offsets = [0, 0], sizes = [%9, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9} -> tensor<?x2048xf32>
  %13 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [4096, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2048xf32>> -> tensor<4096x2048xf32>

  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12, %13 : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%15 : tensor<?x4096xf32>) {

Why is the RHS left as-is?

kuhar avatar Jun 03 '25 18:06 kuhar

I think if the RHS is the weight, then we will need a separate dispatch to add the padding to it?
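
For concreteness, a minimal sketch of what such a standalone pad of the weight could look like (shapes from the example above; %weight stands for the 4096x2048 RHS; op spelling per upstream MLIR's tensor dialect):

%cst = arith.constant 0.000000e+00 : f32
// Pad the K dimension (dim 1) of the 4096x2048 weight by 32 with zeros,
// matching the [0, 32] padding applied to the activation.
%padded = tensor.pad %weight low[0, 0] high[0, 32] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<4096x2048xf32> to tensor<4096x2080xf32>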

MaheshRavishankar avatar Jun 03 '25 20:06 MaheshRavishankar

> So we only pad the LHS?
>
> %12 = iree_tensor_ext.dispatch.tensor.load %10, offsets = [0, 0], sizes = [%9, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9} -> tensor<?x2048xf32>
> %13 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [4096, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2048xf32>> -> tensor<4096x2048xf32>
>
> %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12, %13 : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%15 : tensor<?x4096xf32>) {
>
> Why is the RHS left as-is?

I tried that as well by 'manually' padding the RHS, but it performed worse: 7.28 ms, compared with 6.21 ms for no padding and 6.20 ms for padding only the LHS.

// -----// IR Dump After MaterializeEncodingIntoPaddingPass (iree-codegen-materialize-encoding-into-padding) //----- //
func.func @test_dispatch_1_matmul_like_Dx4096x2048_f32() {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c32_i64 = arith.constant 32 : i64
  %0 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = arith.extui %1 : i32 to i64
  %3 = arith.shli %2, %c32_i64 : i64
  %4 = arith.extui %0 : i32 to i64
  %5 = arith.ori %4, %3 : i64
  %6 = arith.index_castui %5 : i64 to index
  %7 = util.assume.int %6<umin = 0, umax = 9007199254740991> : index
  %8 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2080xf32>>
  %9 = iree_tensor_ext.dispatch.workload.ordinal %7, 0 : index
  %10 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9}
  %11 = hal.interface.binding.subspan layout(<constants = 2, bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%9}
  %12 = iree_tensor_ext.dispatch.tensor.load %10, offsets = [0, 0], sizes = [%9, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9} -> tensor<?x2048xf32>
  %13 = tensor.empty(%9) : tensor<?x4096xf32>
  %14 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [4096, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2080xf32>> -> tensor<4096x2048xf32>
  %15 = linalg.fill ins(%cst : f32) outs(%13 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12, %14 : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%15 : tensor<?x4096xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %17 = arith.mulf %in, %in_0 : f32
    %18 = arith.addf %out, %17 : f32
    linalg.yield %18 : f32
  } -> tensor<?x4096xf32>
  iree_tensor_ext.dispatch.tensor.store %16, %11, offsets = [0, 0], sizes = [%9, 4096], strides = [1, 1] : tensor<?x4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x4096xf32>>{%9}
  return
}

jtuyls avatar Jun 03 '25 21:06 jtuyls

OK, let me run some benchmarks on my own; this is not what I saw in my past experiments.

kuhar avatar Jun 03 '25 22:06 kuhar

@kuhar, I looked at this with @MaheshRavishankar and we noticed that no mfma operations are being generated for the example, as 'M' can be fully dynamic. After specifying that M is a multiple of 16, we do see them, and this changes the numbers, with padding resulting in ~8% better latency:

| IR                 | Baseline runtime (ms) | Pad runtime (ms) |
| ------------------ | --------------------- | ---------------- |
| Matmul_transpose_b | 2.42                  | 2.23             |

EDIT: Note that we still need to look at why we don't see this improvement in the Llama model yet.

The new matmul_transpose_b input IR:

util.func @test(%arg0 : tensor<?x?xf32>, %arg1 : tensor<2048x?xf32>, %arg2 : tensor<4096x2048xf32>) -> tensor<?x4096xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32
  %M = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %MA = util.assume.int %M<umin = 16, umax = 524160, udiv = 16> : index
  %K1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %0 = tensor.empty(%MA) : tensor<?x2048xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x2048xf32>) -> tensor<?x2048xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<2048x?xf32>) outs(%1 : tensor<?x2048xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %42 = arith.mulf %in, %in_0 : f32
    %43 = arith.addf %out, %42 : f32
    linalg.yield %43 : f32
  } -> tensor<?x2048xf32>
  %extract = tensor.extract_slice %arg2[0, 0] [4096, 2048] [1, 1] : tensor<4096x2048xf32> to tensor<4096x2048xf32>
  %3 = tensor.empty(%MA) : tensor<?x4096xf32>
  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<?x4096xf32>) -> tensor<?x4096xf32>
  %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2, %extract : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%4 : tensor<?x4096xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %42 = arith.mulf %in, %in_0 : f32
    %43 = arith.addf %out, %42 : f32
    linalg.yield %43 : f32
  } -> tensor<?x4096xf32>
  util.return %5 : tensor<?x4096xf32>
}

jtuyls avatar Jun 03 '25 22:06 jtuyls

> So we only pad the LHS?
>
> %12 = iree_tensor_ext.dispatch.tensor.load %10, offsets = [0, 0], sizes = [%9, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x2080xf32>>{%9} -> tensor<?x2048xf32>
> %13 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0], sizes = [4096, 2048], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4096x2048xf32>> -> tensor<4096x2048xf32>
>
> %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%12, %13 : tensor<?x2048xf32>, tensor<4096x2048xf32>) outs(%15 : tensor<?x4096xf32>) {
>
> Why is the RHS left as-is?

Supporting RHS padding requires much more work; Mahesh and I discussed it before. It is not an easy move, and we do want to support it when we get cycles. For now, I think the scope of the work is partial padding (i.e., LHS only), and we can extend the mechanism incrementally.

hanhanW avatar Jun 04 '25 02:06 hanhanW

Ah right, that's why we don't expect to see the same speedup as in my initial experiments.

kuhar avatar Jun 04 '25 14:06 kuhar

Update: I manually adjusted the base IR and the IRPA weights file to pad (some of) the weight parameters as well, and this results in some speedup:

| Model  | Baseline runtime (ms) | Pad runtime (ms) |
| ------ | --------------------- | ---------------- |
| Llama3 | 364                   | 349              |

Notes:

  • These numbers are for prefill with sequence length 2048.
  • These numbers are with tuning as well for both (--iree-codegen-enable-default-tuning-specs=true)
  • The machine I am running on feels a bit slow compared with others I have used before, so be careful when comparing these numbers against results from other machines.

Breakdown and comparison of the largest matmul dispatches:

| Dispatches                                     | Baseline runtime (ms) | Pad runtime (ms) |
| ---------------------------------------------- | --------------------- | ---------------- |
| dispatch_678_matmul_like_Dx128256x4096_f32     | 103.07                | 96.28            |
| dispatch_25_matmul_like_Dx4096x14336_f8_f8_f32 | 38.42                 | 32.18            |
| dispatch_24_matmul_like_Dx14336x4096_f8_f8_f32 | 38.26                 | 38.19            |
| dispatch_23_matmul_like_Dx14336x4096_f8_f8_f32 | 33.13                 | 33.56            |

Other dispatches still have the same latency.

cc @kuhar @MaheshRavishankar @hanhanW

jtuyls avatar Jun 06 '25 17:06 jtuyls

> These numbers are with tuning as well for both (--iree-codegen-enable-default-tuning-specs=true)

Does this result in pingpong getting selected or not? If it does, then padding might interfere with the cache swizzle set by the pingpong code.

kuhar avatar Jun 06 '25 19:06 kuhar

> > These numbers are with tuning as well for both (--iree-codegen-enable-default-tuning-specs=true)
>
> Does this result in pingpong getting selected or not? If it does, then padding might interfere with the cache swizzle set by the pingpong code.

It is with the ping-pong scheduling.

MaheshRavishankar avatar Jun 06 '25 23:06 MaheshRavishankar

@jtuyls could you rerun this without pingpong?

kuhar avatar Jun 07 '25 01:06 kuhar

> @jtuyls could you rerun this without pingpong?

@kuhar These are all the e2e numbers with and without ping-pong (--iree-codegen-enable-default-tuning-specs=true):

| Model  | Baseline (ms) | Baseline tuned (ms) | Pad_LHS (ms) | Pad_LHS tuned (ms) | Pad_LHS_RHS (ms) | Pad_LHS_RHS tuned (ms) |
| ------ | ------------- | ------------------- | ------------ | ------------------ | ---------------- | ---------------------- |
| Llama3 | 375           | 343                 | 380          | 347                | 366              | 332                    |

From this, it doesn't look like the difference is larger when ping-pong is disabled.

Notes:

  • This was run on a different node than the earlier numbers, as I have not been able to reach that machine since the weekend; the numbers on this node are better overall.
  • This is for sequence length 2048

jtuyls avatar Jun 10 '25 08:06 jtuyls

Another issue with the test IR from https://github.com/iree-org/iree/issues/20835#issuecomment-2937559669 is that the final matmul is (f32, f32) -> (f32) and won't use any of the efficient mfma instructions because of the operand types. I'm working on a new isolated benchmark for f16 -> f32 and f8 -> f32.

kuhar avatar Jun 10 '25 20:06 kuhar

> Another issue with the test IR from #20835 (comment) is that the final matmul is (f32, f32) -> (f32) and won't use any of the efficient mfma instructions because of the operand types. I'm working on a new isolated benchmark for f16 -> f32 and f8 -> f32.

Yes, that's what I mentioned last week in the codegen sync and on Slack as well.

jtuyls avatar Jun 11 '25 07:06 jtuyls

I verified the new code generated by https://github.com/jtuyls/iree/tree/padEncodinge2e-2 and I think it works just like we wanted it to and gives the expected improvements.

I tried it on the shapes that gave us the biggest improvement in the past: https://github.com/nod-ai/playbook/issues/63#issuecomment-2593801793, using this micro-benchmark that forces both LHS and RHS to come from dispatches:

!matA = tensor<512x512xf16>
!matB = tensor<14336x512xf16>
!matC = tensor<4096x512xf16>

!matLhs = tensor<512x14336xf16>
!matRhs = tensor<4096x14336xf16>

!matRes = tensor<512x4096xf32>

func.func @main(%arg0: !matA, %arg1: !matB, %arg2: !matC) -> !matRes {
  %f32 = arith.constant 0.000000e+00 : f32
  %f16 = arith.constant 0.000000e+00 : f16
  %5 = tensor.empty() : !matLhs
  %6 = linalg.fill ins(%f16 : f16) outs(%5 : !matLhs) -> !matLhs
  %lhs = linalg.matmul_transpose_b ins(%arg0, %arg1 : !matA, !matB) outs(%6 : !matLhs) -> !matLhs
  %8 = tensor.empty() : !matRhs
  %9 = linalg.fill ins(%f16 : f16) outs(%8 : !matRhs) -> !matRhs
  %rhs = linalg.matmul_transpose_b ins(%arg2, %arg1 : !matC, !matB) outs(%9 : !matRhs) -> !matRhs
  %11 = tensor.empty() : !matRes
  %12 = linalg.fill ins(%f32 : f32) outs(%11 : !matRes) -> !matRes
  %13 = linalg.matmul_transpose_b ins(%lhs, %rhs : !matLhs, !matRhs) outs(%12 : !matRes) -> !matRes
  return %13 : !matRes
}

The padding values are as expected, and the only meaningful difference in the generated assembly is the strides used to address into the input buffers.
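
(For a rough sense of scale, assuming the same [0, 32] reduction-dimension padding as in the dumps earlier in the thread: an unpadded row of 14336 f16 elements spans 28672 bytes, while a padded row of 14368 elements spans 28736 bytes, which is where the addressing difference would show up.)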

I observed the following speedup:

  • Orig: 324 us
  • Padding: 262 us

The speedup is ~20%. This is in SPX, so the absolute numbers are not the same as my previous results on CPX -- I think CPX can account for the remaining ~10 pp.
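
(As a quick sanity check on that figure: (324 - 262) / 324 ≈ 19%, consistent with ~20%.)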

I will run this under a profiler and double-check the counters next.

kuhar avatar Jun 13 '25 21:06 kuhar

As to why some shapes observe a higher speedup than others: it boils down to how many loads of each operand get generated per loop iteration. Padding helps most when each of LHS and RHS loads a multiple of 4 row chunks per iteration -- in that case, the memory bandwidth can increase by 4x. This usually happens when the reduction dimension is large or when the assumptions specify a large udiv value for K. In the test IR higher up in the issue, I only counted two dwordx4 loads for the LHS that were affected by padding, and in that case the maximum bandwidth increase was 2x, and only for the LHS.

Some of the previous isolated experiments that increased the logical tensor size (not just the strides) could have affected tile size selection and thus the number of adjacent load instructions, which happened to benefit more from padding than with the base tile size. It's hard to tell without looking at the exact assembly -- LLVM makes the final decision about what to unroll and by how much.

kuhar avatar Jun 14 '25 03:06 kuhar

Threadtrace confirms that padding helps at the level of buffer load instructions:

[thread trace screenshots]

kuhar avatar Jun 16 '25 19:06 kuhar

omniperf also confirms higher cache bandwidth with padding:

[omniperf screenshots]

Commands:

rocprof-compute profile --name <name> -d 3 --no-roof -- ~/iree/relass/tools/iree-benchmark-module --device=hip://0 --device_allocator=caching --hip_use_streams=true --module=<name.vmfb> --benchmark_repetitions=1

rocprof-compute analyze -p workloads/<name>/MI300X_A1/ > <name>.stats.txt

kuhar avatar Jun 16 '25 20:06 kuhar