
[Mistral] Performance degradation with VMFB containing prefill functions of multiple batch sizes

Open pravg-amd opened this issue 8 months ago • 11 comments

What happened?

When a single VMFB contains prefill functions for multiple batch sizes (2 and 8), running the batch-size-8 prefill function is noticeably slower than running a VMFB that contains only the batch-size-8 prefill function.

The degradation is not visible with other combinations such as batch sizes (4, 8).

prefill_bs_4_8.txt prefill_bs_2_8.txt prefill_bs_8.txt

Steps to reproduce your issue

The MLIR files with prefill batch sizes (2, 8), (4, 8), and (8) are attached to this ticket.

Generate the VMFB using the following command:

iree-compile prefill_bs_2_8.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3  \
    --iree-hal-indirect-command-buffers=true  \
    --iree-stream-resource-memory-model=discrete  \
    --iree-hal-memoization=true \
    -o quark_mistral_nemo.vmfb

Run the benchmark for prefill_bs8 (on SharkMI300x-3):

iree-benchmark-module \
    --device=hip://2 \
    --device_allocator=caching \
    --module=quark_mistral_nemo.vmfb \
    --parameters=model=/data/Mistral-Nemo-Instruct-2407-FP8/quark_mistral_nemo.irpa \
    --function=prefill_bs8 \
    --input=8x1024xsi64 \
    --input=8xsi64 \
    --input=8x32xsi64 \
    --input=1024x2621440xf8E4M3FNUZ \
    --benchmark_repetitions=5

With prefill_bs_8.mlir / prefill_bs_4_8.mlir

BM_prefill_bs8/process_time/real_time               639 ms          639 ms            1 items_per_second=1.56409/s
BM_prefill_bs8/process_time/real_time               640 ms          640 ms            1 items_per_second=1.56353/s
BM_prefill_bs8/process_time/real_time               639 ms          639 ms            1 items_per_second=1.56485/s
BM_prefill_bs8/process_time/real_time               639 ms          640 ms            1 items_per_second=1.56472/s
BM_prefill_bs8/process_time/real_time               639 ms          640 ms            1 items_per_second=1.56405/s
BM_prefill_bs8/process_time/real_time_mean          639 ms          640 ms            5 items_per_second=1.56425/s
BM_prefill_bs8/process_time/real_time_median        639 ms          640 ms            5 items_per_second=1.56409/s
BM_prefill_bs8/process_time/real_time_stddev      0.221 ms        0.371 ms            5 items_per_second=540.428u/s
BM_prefill_bs8/process_time/real_time_cv           0.03 %          0.06 %             5 items_per_second=0.03%

With prefill_bs_2_8.mlir

BM_prefill_bs8/process_time/real_time               873 ms          873 ms            1 items_per_second=1.14508/s
BM_prefill_bs8/process_time/real_time               874 ms          875 ms            1 items_per_second=1.14362/s
BM_prefill_bs8/process_time/real_time               874 ms          874 ms            1 items_per_second=1.14405/s
BM_prefill_bs8/process_time/real_time               873 ms          874 ms            1 items_per_second=1.14492/s
BM_prefill_bs8/process_time/real_time               874 ms          874 ms            1 items_per_second=1.14404/s
BM_prefill_bs8/process_time/real_time_mean          874 ms          874 ms            5 items_per_second=1.14434/s
BM_prefill_bs8/process_time/real_time_median        874 ms          874 ms            5 items_per_second=1.14405/s
BM_prefill_bs8/process_time/real_time_stddev      0.478 ms        0.574 ms            5 items_per_second=626.199u/s

What component(s) does this issue relate to?

Compiler

Version information

IREE compiler version 3.5.0rc20250514 @ d63e15e15509784de68f1e39f86f78c980031dda

Additional context

Steps to download the model and IRPA files are available here:

https://gist.github.com/pravg-amd/1b9f3e3c3abcb6f2c35fdc10a09db09d

pravg-amd avatar May 16 '25 12:05 pravg-amd

Initial analysis

The DeduplicateExecutables pass rewrites the following dispatch

    %1520 = flow.dispatch @prefill_bs8$async_dispatch_805::@prefill_bs8$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32[%1519, %13](%1519, %1518, %__hoisted_tensor_131072x5120xf16_578, %13) : (index, tensor<?x5120xf16>{%13}, tensor<131072x5120xf16>, index) -> tensor<?x131072xf16>{%13}


to

    %1520 = flow.dispatch @prefill_bs2$async_dispatch_805::@prefill_bs2$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32[%1519, %13](%1519, %1518, %__hoisted_tensor_131072x5120xf16_578, %13) : (index, tensor<?x5120xf16>{%13}, tensor<131072x5120xf16>, index) -> tensor<?x131072xf16>{%13}

At the HAL level for the prefill_bs_2_8 case, the workgroup counts are (512, 4, z):

    %ordinal_204 = hal.executable.export.ordinal target(@module_linked::@rocm_hsaco_fb::@prefill_bs2$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32) : index
    %147 = arith.divsi %10, %c64 : index
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe_203 : !hal.executable)[%ordinal_204] workgroups([%c512, %c4, %147]) constants([%55, %57, %c-267351040_i32, %53]) bindings([
      (%transient_buffer_3 : !hal.buffer)[%c0, %51],
      (%__hoisted_tensor_131072x5120xf16 : !hal.buffer)[%c0, %c10739587072],
      (%transient_buffer : !hal.buffer)[%c0, %32]
    ]) flags("None")

At the HAL level for the prefill_bs_4_8 case, the workgroup counts are (1024, 2, z):

    %ordinal_204 = hal.executable.export.ordinal target(@module_linked::@rocm_hsaco_fb::@prefill_bs4$async_dispatch_805_matmul_like_Dx131072x5120_f16xf16xf32) : index
    %147 = arith.divsi %10, %c128 : index
    hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe_203 : !hal.executable)[%ordinal_204] workgroups([%c1024, %c2, %147]) constants([%55, %57, %c-267351040_i32, %53]) bindings([
      (%transient_buffer_3 : !hal.buffer)[%c0, %51],
      (%__hoisted_tensor_131072x5120xf16 : !hal.buffer)[%c0, %c10739587072],
      (%transient_buffer : !hal.buffer)[%c0, %32]
    ]) flags("None")

I verified this by disabling DeduplicateExecutables and seeing the performance come back, though it increases the VMFB size.

CC: @kumardeepakamd @MaheshRavishankar @pdhirajkumarprasad

pravg-amd avatar May 16 '25 13:05 pravg-amd

Good find! This looks like a case that specialization should be able to handle - the analysis information derived from the dispatch sites should be present during executable configuration, but AFAIK today that's not really used by codegen. It'd be good to check if what's required to specialize is there (--iree-hal-dump-executable-sources-to= should show it). This same situation would arise if a single input function dispatched with different sizes, and this particular case of globbing things together just happens to definitely show it.
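
One way to check (a sketch reusing the compile command from the issue description; the dump directory name is arbitrary):

iree-compile prefill_bs_2_8.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3  \
    --iree-hal-indirect-command-buffers=true  \
    --iree-stream-resource-memory-model=discrete  \
    --iree-hal-memoization=true \
    --iree-hal-dump-executable-sources-to=dump_sources \
    -o quark_mistral_nemo.vmfb

Then look for prefill_bs2$async_dispatch_805 under dump_sources/ to check whether the util.assume.int ranges from both callsites are present in the executable source.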

benvanik avatar May 16 '25 15:05 benvanik

(also, great triage! thanks for digging in!)

benvanik avatar May 16 '25 15:05 benvanik

The following packages need to be installed to generate the IRPA file:

torch
pytest
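
For example (assuming a standard pip-based environment):

pip install torch pytest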

amd-vivekag avatar May 16 '25 15:05 amd-vivekag

I have a smaller reproducer for the above issue. Thanks @pashu123 for helping debug this.

module {
    
  func.func @prefill_bs2(%arg0: !torch.vtensor<[2, ?,5120],f16>, %arg1: !torch.vtensor<[5120,131072],f16>) -> !torch.vtensor<[?,131072],f16> attributes {torch.assume_strict_symbolic_shapes} {
      %965 = torch.symbolic_int "s1" {min_val = 2, max_val = 4095} : !torch.int
      torch.bind_symbolic_shape %arg0, [%965], affine_map<()[s0] -> (2, s0 * 32)> : !torch.vtensor<[2,?, 5120],f16>
      %int2 = torch.constant.int 2
      %int1 = torch.constant.int 1
      %971 = torch.aten.size.int %arg0, %int1: !torch.vtensor<[2, ?, 5120],f16>, !torch.int -> !torch.int
      %15268 = torch.aten.mul.int %int2, %971 : !torch.int, !torch.int -> !torch.int
      %int5120_15196 = torch.constant.int 5120
      %15269 = torch.prim.ListConstruct %15268, %int5120_15196 : (!torch.int, !torch.int) -> !torch.list<int>
      %15270 = torch.aten.view %arg0, %15269 : !torch.vtensor<[2,?,5120],f16>, !torch.list<int> -> !torch.vtensor<[?,5120],f16>
      torch.bind_symbolic_shape %15270, [%965], affine_map<()[s0] -> (s0 * 64, 5120)> : !torch.vtensor<[?,5120],f16>
      %15271 = torch.aten.mm %15270, %arg1 : !torch.vtensor<[?,5120],f16>, !torch.vtensor<[5120,131072],f16> -> !torch.vtensor<[?,131072],f16>
      torch.bind_symbolic_shape %15271, [%965], affine_map<()[s0] -> (s0 * 64, 131072)> : !torch.vtensor<[?,131072],f16>
      return %15271 : !torch.vtensor<[?,131072],f16>
    }

    func.func @prefill_bs8(%arg0: !torch.vtensor<[8, ?,5120],f16>, %arg1: !torch.vtensor<[5120,131072],f16>) -> !torch.vtensor<[?,131072],f16> attributes {torch.assume_strict_symbolic_shapes} {
      %965 = torch.symbolic_int "s1" {min_val = 2, max_val = 4095} : !torch.int
      torch.bind_symbolic_shape %arg0, [%965], affine_map<()[s0] -> (8, s0 * 32)> : !torch.vtensor<[8,?, 5120],f16>
      %int8_15195 = torch.constant.int 8
      %int1 = torch.constant.int 1
      %971 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[8, ?, 5120],f16>, !torch.int -> !torch.int
      %15268 = torch.aten.mul.int %int8_15195, %971 : !torch.int, !torch.int -> !torch.int
      %int5120_15196 = torch.constant.int 5120
      %15269 = torch.prim.ListConstruct %15268, %int5120_15196 : (!torch.int, !torch.int) -> !torch.list<int>
      %15270 = torch.aten.view %arg0, %15269 : !torch.vtensor<[8,?,5120],f16>, !torch.list<int> -> !torch.vtensor<[?,5120],f16>
      torch.bind_symbolic_shape %15270, [%965], affine_map<()[s0] -> (s0 * 256, 5120)> : !torch.vtensor<[?,5120],f16>
      %15271 = torch.aten.mm %15270, %arg1 : !torch.vtensor<[?,5120],f16>, !torch.vtensor<[5120,131072],f16> -> !torch.vtensor<[?,131072],f16>
      torch.bind_symbolic_shape %15271, [%965], affine_map<()[s0] -> (s0 * 256, 131072)> : !torch.vtensor<[?,131072],f16>
      return %15271 : !torch.vtensor<[?,131072],f16>
    }
}
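
A sketch of the compile step that produces the test.vmfb used in the benchmark command further below, reusing the flags from the issue description (the input file name reproducer.mlir is a placeholder for the module above; not all of these flags may be strictly needed for this small case):

iree-compile reproducer.mlir \
    --iree-hal-target-device=hip \
    --iree-hip-target=gfx942 \
    --iree-opt-level=O3  \
    --iree-hal-indirect-command-buffers=true  \
    --iree-stream-resource-memory-model=discrete  \
    --iree-hal-memoization=true \
    -o test.vmfb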

Without DeduplicateExecutables, the lowering configs generated for the individual functions have the following workgroup tile sizes and subgroup counts.

For the function prefill_bs2

attrs =  {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 0, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 16, 256, 0]}>}

For the function prefill_bs8

attrs =  {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [1, 64, 128, 0]}>}

With DeduplicateExecutables, the lowering config is the same as prefill_bs2's for both functions, which causes the performance regression.

attrs =  {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 0, 64], subgroup_m_count = 1 : i64, subgroup_n_count = 4 : i64, workgroup = [1, 16, 256, 0]}>}
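
As a rough back-of-the-envelope check (assuming the workgroup entries are [batch, M, N, K] tile sizes and using the 8x2048x5120 benchmark input below, i.e. M = 8 * 2048 = 16384 and N = 131072):

    # The bs2 config tiles the output in 16x256 blocks, the bs8 config in 64x128 blocks,
    # so the deduplicated module launches twice as many, smaller-M workgroups for the
    # same bs8 problem (consistent with the (512, 4, M/64) vs (1024, 2, M/128) grids
    # seen at the HAL level above).
    echo "bs2 config: $(( (16384 / 16) * (131072 / 256) )) workgroups"   # 524288
    echo "bs8 config: $(( (16384 / 64) * (131072 / 128) )) workgroups"   # 262144

The workgroup count alone may not explain the whole 140 ms vs 509 ms gap, but it illustrates how the bs2-tuned tiling carries over to the bs8 entry point once the executables are deduplicated.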

Command to run the benchmark for the above IR

iree-benchmark-module \
    --device=hip://11 \
    --module=test.vmfb \
    --parameters=model=/shark-dev/Mistral-Nemo-Instruct-2407-FP8/quark_mistral_nemo.irpa \
    --function=prefill_bs8 \
    --input=8x2048x5120xf16 \
    --input=5120x131072xf16 \
    --benchmark_repetitions=5 \
    --device_allocator=caching

Benchmarks

Without DeduplicateExecutables

-------------------------------------------------------------------------------------------------------
Benchmark                                             Time             CPU   Iterations UserCounters...
-------------------------------------------------------------------------------------------------------

BM_prefill_bs8/process_time/real_time               140 ms          140 ms            5 items_per_second=7.1644/s
BM_prefill_bs8/process_time/real_time               140 ms          140 ms            5 items_per_second=7.15495/s
BM_prefill_bs8/process_time/real_time               140 ms          140 ms            5 items_per_second=7.15178/s
BM_prefill_bs8/process_time/real_time               140 ms          140 ms            5 items_per_second=7.14658/s
BM_prefill_bs8/process_time/real_time               140 ms          140 ms            5 items_per_second=7.13627/s
BM_prefill_bs8/process_time/real_time_mean          140 ms          140 ms            5 items_per_second=7.1508/s
BM_prefill_bs8/process_time/real_time_median        140 ms          140 ms            5 items_per_second=7.15178/s
BM_prefill_bs8/process_time/real_time_stddev      0.203 ms        0.191 ms            5 items_per_second=0.0103955/s
BM_prefill_bs8/process_time/real_time_cv           0.15 %          0.14 %             5 items_per_second=0.15%

With DeduplicateExecutables

-------------------------------------------------------------------------------------------------------
Benchmark                                             Time             CPU   Iterations UserCounters...
-------------------------------------------------------------------------------------------------------

BM_prefill_bs8/process_time/real_time               509 ms          509 ms            1 items_per_second=1.96303/s
BM_prefill_bs8/process_time/real_time               508 ms          508 ms            1 items_per_second=1.9681/s
BM_prefill_bs8/process_time/real_time               508 ms          508 ms            1 items_per_second=1.96712/s
BM_prefill_bs8/process_time/real_time               508 ms          508 ms            1 items_per_second=1.96777/s
BM_prefill_bs8/process_time/real_time               509 ms          509 ms            1 items_per_second=1.96553/s
BM_prefill_bs8/process_time/real_time_mean          509 ms          509 ms            5 items_per_second=1.96631/s
BM_prefill_bs8/process_time/real_time_median        508 ms          508 ms            5 items_per_second=1.96712/s
BM_prefill_bs8/process_time/real_time_stddev      0.539 ms        0.542 ms            5 items_per_second=2.08072m/s
BM_prefill_bs8/process_time/real_time_cv           0.11 %          0.11 %             5 items_per_second=0.11%

Dump with --iree-hal-dump-executable-sources-to

With DeduplicateExecutables pass https://gist.github.com/pravg-amd/003c9268b24200a7db2c06b938ee3285#file-with_deduplicate-mlir

Without DeduplicateExecutables pass for bs2 and bs8

https://gist.github.com/pravg-amd/003c9268b24200a7db2c06b938ee3285#file-without_dedulicate_bs8-mlir https://gist.github.com/pravg-amd/003c9268b24200a7db2c06b938ee3285#file-without_deduplicate_bs2-mlir

@qedawkins I see you are working on specialization (https://github.com/iree-org/iree/pull/20771). Is the above issue related to this? If so, would it be handled in further patches?

@benvanik @kumardeepakamd @pashu123 @MaheshRavishankar @pdhirajkumarprasad

pravg-amd avatar May 27 '25 08:05 pravg-amd

Meant to respond to this last week. Yes, specialization work should help here but might not close the whole gap. Codegen can specialize by callsite, but once we've deduplicated executables like this, we'll be relying on host side CSE to recover uniformity across queries per batch-size entry point. I don't know if we have the host side range annotations needed to make this possible today though.

In general, this issue is part of a larger tradeoff between space (aggressive deduplication) and performance, but hopefully there isn't much real performance cost.

qedawkins avatar May 27 '25 17:05 qedawkins

It's also a compile time tradeoff - compiling 100 executables is way faster than compiling 1500. When we specialize we have to be certain we're doing it for meaningful reasons.

It looks like the information is all available:

        // bs2
        %3:2 = util.assume.int 
            %1<umin = 64, umax = 262080, udiv = 64>, 
            %2<umin = 64, umax = 262080, udiv = 64>
          : index, index
        // bs8
        %3:2 = util.assume.int 
            %1<umin = 256, umax = 1048320, udiv = 256>, 
            %2<umin = 256, umax = 1048320, udiv = 256>
          : index, index
        // combined
        %3:2 = util.assume.int 
            %1[<umin = 64, umax = 262080, udiv = 64>, <umin = 256, umax = 1048320, udiv = 256>], 
            %2[<umin = 64, umax = 262080, udiv = 64>, <umin = 256, umax = 1048320, udiv = 256>]
          : index, index

So all the same information is there in the deduplicated case, and it's even nicely baked out as the uniqued set of potential values (the umin is 64 for both %1 and %2, or the umin is 256 for both %1 and %2).

In my mind the first step of specialization is (conceptually) changing the lowering config API to return a list of configs and the corresponding assume sets for each. E.g. here I'd expect it to return two (the unique ones with different subgroup sizes posted above for the non-deduplicated executables). Whenever more than one is returned we'd stamp out a new specialized function with the assume op trimmed to just that set of values. The goal is to avoid compiling to the same thing from two routes - so if we emit two different specializations and they both produce the exact same ISA instructions it means the lowering config is overspecified. That will still happen, but it's a useful metric to track (as it's directly tied to compile time). If we can deduplicate functions eagerly as we lower to LLVM IR we'll avoid compiling duplicates all the way down and relying on the linker to dedupe.

benvanik avatar May 27 '25 18:05 benvanik

I was trying to point out a different issue: I'm not sure the host has enough information to optimize away the entry point calculation. It's maybe fine here because batch sizes are static per function and still split by callsite on the executable, but we need to make sure the callsites still have/use the assumes that the range analysis sources the executable annotations from after we've specialized.

@matmul {
  util.assume.int M = [2 -> 4096 += 2] or [8 -> 16384 += 8]
}
@func_batch_size_2 {
  util.assume.int M = [2 -> 4096 += 2]
  call @matmul [2*M, N1, K1]
}
@func_batch_size_8 {
  util.assume.int M = [8 -> 16384 += 8]
  call @matmul [8*M, N1, K1]
}

In my mind the first step of specialization is (conceptually) changing the lowering config API to return a list of configs and the corresponding assume sets for each. E.g. here I'd expect it to return two (the unique ones with different subgroup sizes posted above for the non-deduplicated executables).

Yeah I was planning something like this as a step 2 (first #20771 adds an attribute to do the forking). The first step needed here is changing all of the lowering config logic to stop querying linalgOp.getStaticLoopRanges and instead load up the integer range analysis and query iteration spaces in terms of assume ranges (which hopefully just walks up and finds the assume sets). Then the existing logic can stay mostly the same and just return a list of configs. Also probably need to make #20771 smarter so it culls specializations based on the list rather than just the union.

qedawkins avatar May 27 '25 19:05 qedawkins

The assume ops in the executable are derived from the host information, so by definition it has the information. Whether we have the right folders or not is another question :)

benvanik avatar May 27 '25 19:05 benvanik

OK, just checked: util.assume.int ops aren't dropped until here, which is more than late enough. So I was just imagining problems then, modulo missing folders.

qedawkins avatar May 27 '25 19:05 qedawkins

My hope (having looked at your PR) is that we can fold the arith AndI/etc ops - if not, we'll need to have some ops we can fold (util.int.whatever_the_arithmetic_is_doing) emitted instead

benvanik avatar May 27 '25 19:05 benvanik

@qedawkins do we have any update on this?

pdhirajkumarprasad avatar Jun 27 '25 04:06 pdhirajkumarprasad

I don't think we need to revisit this until after the specialization work lands. That might fix the issue.

MaheshRavishankar avatar Jun 27 '25 18:06 MaheshRavishankar