
[stream] eliding async slices

Status: Open. raikonenfnu opened this issue.

Request description

For our use case of updating the KV cache during a decode step, we need to slice a global of shape tensor<1x4096x32x128xf16> into tensor<1xstepx32x128xf16>. This slice of the cache is used as an input to our forward/decode step. At every step we generate a new cache entry of shape tensor<1x1x32x128xf16>, which we would like to insert into the main global at entry step+1, i.e. k_cache_layer_0[0:1, step+1, 0:32, 0:128] = new_entry.

The current problem is that, to slice a global for use as an input, we lower these async.slice ops to stream.cmd.copy. This means extra allocations as large as the entire KV cache up to the current time step. That is acceptable for small step counts, but as the context length grows it makes decoding unbearably slow.
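To put rough numbers on this: one 1x1x32x128xf16 entry is 32 * 128 * 2 = 8192 bytes (the %c8192 constant in the repro below), and the 33546240-byte globals in the repro hold 4095 such entries. The sketch below estimates the copy traffic, assuming 64 cache globals (K and V for 32 layers, matching the 64 concurrent slices in the IR below) and one full prefix copy per global per decode step. These are illustrative assumptions, not measurements.

```python
# Back-of-the-envelope cost of lowering stream.async.slice to a copy.
# Assumptions (illustrative only): 64 cache globals, one full prefix copy
# per global per decode step, f16 entries of shape 1x1x32x128.

ENTRY_BYTES = 32 * 128 * 2   # one 1x1x32x128xf16 cache entry = 8192 bytes
NUM_CACHES = 64              # assumed: K and V caches for 32 layers


def bytes_copied_at_step(step: int) -> int:
    """Bytes copied at one step if every cache prefix [0, step) is sliced by copy."""
    return NUM_CACHES * step * ENTRY_BYTES


def total_bytes_copied(num_steps: int) -> int:
    """Cumulative bytes copied over steps 0..num_steps-1 (grows quadratically)."""
    return sum(bytes_copied_at_step(s) for s in range(num_steps))


if __name__ == "__main__":
    for steps in (128, 1024, 4095):
        per_step_mib = bytes_copied_at_step(steps) / 2**20
        total_gib = total_bytes_copied(steps) / 2**30
        print(f"step {steps:5d}: {per_step_mib:8.1f} MiB copied this step, "
              f"{total_gib:8.1f} GiB copied so far")
```

The per-step cost is linear in the step and the cumulative cost is quadratic, which matches the observation that decode slows down as the context grows.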

What we have right now looks something like (load global -> slice -> dispatch):

%_global_k_caches.layer_idx.0 = util.global.load @_global_k_caches.layer_idx.0 : !stream.resource<variable>
%results:65, %result_timepoint = stream.async.execute await(%19) => with(%_global_k_caches.layer_idx.0 as %arg1: !stream.resource<variable>{%c33546240}, ..) {
  %23:64 = stream.async.concurrent with(%arg1 as %arg869: !stream.resource<variable>{%c33546240}, %arg2 as %arg870: !stream.resource<variable>{%c33546240}, ..) {
    %637 = stream.async.slice %arg869[%c0 to %1] : !stream.resource<variable>{%c33546240} -> !stream.resource<variable>{%1}
    stream.yield %637, ...
  }
  %24:66 = stream.async.concurrent with(%23#0 as %arg869: !stream.resource<variable>{%1}, ..) {
    %637 = stream.async.dispatch @run_forward_dispatch_0::@run_forward_dispatch_0_transpose_1xDx32x128_f16[%_global_seq_step.global](%arg869[%c0 to %1 for %1], %_global_seq_step.global) : (!stream.resource<variable>{%1}, index) -> !stream.resource<transient>{%1}
  }
}

The full IR has more context; look specifically at the run_forward function.

And ideally it should look more like (load global -> subview -> dispatch):

%_global_k_caches.layer_idx.0 = util.global.load @_global_k_caches.layer_idx.0 : !stream.resource<variable>
%results:65, %result_timepoint = stream.async.execute await(%19) => with(%_global_k_caches.layer_idx.0 as %arg1: !stream.resource<variable>{%c33546240}, ..) {
  %23:64 = stream.async.concurrent with(%arg1 as %arg869: !stream.resource<variable>{%c33546240}, %arg2 as %arg870: !stream.resource<variable>{%c33546240}, ..) {
    %637 = stream.async.subview %arg869[%c0 to %1] : !stream.resource<variable>{%c33546240} -> !stream.resource<variable>{%1}
    stream.yield %637, ...
  }
  %24:66 = stream.async.concurrent with(%23#0 as %arg869: !stream.resource<variable>{%1}, ..) {
    %637 = stream.async.dispatch @run_forward_dispatch_0::@run_forward_dispatch_0_transpose_1xDx32x128_f16[%_global_seq_step.global](%arg869[%c0 to %1 for %1], %_global_seq_step.global) : (!stream.resource<variable>{%1}, index) -> !stream.resource<transient>{%1}
  }
}

The key difference we want between async.slice and subview is that async.slice lowers to a copy of the source range, whereas subview should lower to an alias of that range of the source.
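As an analogy only (plain Python buffers, not IREE's runtime), the difference is the one between copying a `bytearray` slice into `bytes`, which allocates and copies, and slicing a `memoryview` of it, which aliases the same storage. The aliasing view also observes later writes to the source inside its range, which is why folding a slice into a subview is only safe when reads and writes do not interfere.

```python
# Analogy: copy-on-slice (like async.slice lowered to a copy) versus an
# aliasing view (like the desired subview), using plain Python buffers.

ENTRY_BYTES = 8192
step = 3

cache = bytearray(ENTRY_BYTES * 8)        # stand-in for one KV-cache global
for i in range(step):                     # fill the first `step` entries
    cache[i * ENTRY_BYTES:(i + 1) * ENTRY_BYTES] = bytes([i + 1]) * ENTRY_BYTES

copied = bytes(cache[0:step * ENTRY_BYTES])          # new allocation + copy
aliased = memoryview(cache)[0:step * ENTRY_BYTES]    # no copy, shares storage

# Writing the new entry at [step, step + 1) is disjoint from the slice,
# so the alias still matches the copy.
cache[step * ENTRY_BYTES:(step + 1) * ENTRY_BYTES] = b"\xaa" * ENTRY_BYTES
print(copied == aliased.tobytes())   # True

# A write inside the sliced range is visible through the alias only.
cache[0] = 0xFF
print(copied[0], aliased[0])         # 1 255
```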

@benvanik proposed running a DFX analysis over slice ranges and folding async.slice ops into subviews iff the read and write ranges are independent of each other. Through Discord discussions, we agreed to add this to the existing ElideAsyncCopies.cpp.
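A minimal sketch of the independence condition on concrete numbers, using the offsets from the repro below: the slice reads [%c0 to %0] and the update writes [%0 to %3], where %0 = step * 8192 and %3 = %0 + 8192. Ranges are treated as half-open byte ranges, matching the `[offset to end]` form whose result size is end - offset. The helper names are hypothetical; this is not IREE's DFX analysis.

```python
# Sketch of the "read and write ranges are independent" test for folding a
# slice into a subview. Helper names are hypothetical, not IREE APIs.
from typing import Iterable, Tuple

Range = Tuple[int, int]  # half-open byte range [start, end)


def overlaps(a: Range, b: Range) -> bool:
    """True if half-open ranges a and b share at least one byte."""
    return a[0] < b[1] and b[0] < a[1]


def can_fold_slice_to_subview(slice_range: Range, write_ranges: Iterable[Range]) -> bool:
    """A slice may alias its source iff no write to the source touches the sliced range."""
    return not any(overlaps(slice_range, w) for w in write_ranges)


ENTRY_BYTES = 8192  # %c8192 in the repro


def repro_ranges(step: int) -> Tuple[Range, Range]:
    """Slice [%c0 to %0] and update [%0 to %3] from the repro."""
    off0 = step * ENTRY_BYTES      # %0 = arith.muli %step, %c8192
    off3 = off0 + ENTRY_BYTES      # %3 = arith.addi %0, %c8192
    return (0, off0), (off0, off3)


for step in (0, 1, 17, 4094):
    sl, upd = repro_ranges(step)
    assert can_fold_slice_to_subview(sl, [upd])
```

In the compiler the offsets are SSA values rather than integers, so the real analysis has to prove disjointness symbolically (here both ranges meet at the same value %0); the loop above only spot-checks a few steps.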

I have come up with a simple repro to iterate on:

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx1100", ukernels = "argmax"}>
#map = affine_map<()[s0] -> (s0 + 1)>
#device_target_rocm = #hal.device.target<"rocm", {executable_targets = [#executable_target_rocm_hsaco_fb], legacy_sync}>
module {
  module @state_update attributes {hal.device.targets = [#device_target_rocm]} {
    util.global private mutable @_global_seq_step.global {noinline} = 0 : index
    util.global private mutable @_global_k_caches.layer_idx.0__timepoint = #stream.timepoint<immediate> : !stream.timepoint
    util.global private mutable @_global_k_caches.layer_idx.0 : !stream.resource<variable>
    util.global private mutable @_global_k_caches.layer_idx.1 : !stream.resource<variable>
    util.global private mutable @_global_v_caches.layer_idx.0 : !stream.resource<variable>
    stream.executable private @run_forward_dispatch_804 {
      stream.executable.export public @run_forward_dispatch_804_slow_memcpy workgroups(%arg0: index, %arg1: index) -> (index, index, index) {
        %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg0, %arg1
        stream.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @run_forward_dispatch_804_slow_memcpy(%arg0: index, %arg1: !stream.binding, %arg2: index, %arg3: !stream.binding) {
          %c-1_i64 = arith.constant -1 : i64
          %c0_i64 = arith.constant 0 : i64
          %c0 = arith.constant 0 : index
          %0 = stream.binding.subspan %arg3[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<32x128xf16>>
          %1 = flow.dispatch.workload.ordinal %arg2, 1 : index
          %2 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<1x32x?x128xf16>>{%1}
          %3 = flow.dispatch.workload.ordinal %arg0, 0 : index
          %4 = affine.apply #map()[%3]
          %5 = arith.index_cast %4 : index to i64
          %6 = arith.addi %5, %c-1_i64 : i64
          %7 = arith.cmpi slt, %6, %c0_i64 : i64
          %8 = arith.select %7, %c0_i64, %6 : i64
          %9 = arith.cmpi sgt, %8, %5 : i64
          %10 = arith.select %9, %5, %8 : i64
          %11 = arith.index_cast %10 : i64 to index
          %12 = flow.dispatch.tensor.load %2, offsets = [0, 0, %11, 0], sizes = [1, 32, 1, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x32x?x128xf16>>{%1} -> tensor<32x128xf16>
          flow.dispatch.tensor.store %12, %0, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : tensor<32x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<32x128xf16>>
          return
        }
      }
    }
    util.func public @run_forward(%arg0: !hal.buffer_view) attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @run_forward(%input0: tensor<1x1xi64>) -> ()"}} {
      %_global_seq_step.global = util.global.load @_global_seq_step.global : index
      %_global_k_caches.layer_idx.0__timepoint = util.global.load @_global_k_caches.layer_idx.0__timepoint : !stream.timepoint
      %_global_k_caches.layer_idx.0 = util.global.load @_global_k_caches.layer_idx.0 : !stream.resource<variable>
      %_global_k_caches.layer_idx.0__timepoint_0 = util.global.load @_global_k_caches.layer_idx.0__timepoint : !stream.timepoint
      %_global_v_caches.layer_idx.0 = util.global.load @_global_v_caches.layer_idx.0 : !stream.resource<variable>
      %_global_k_caches.layer_idx.0__timepoint_1 = util.global.load @_global_k_caches.layer_idx.0__timepoint : !stream.timepoint
      %_global_k_caches.layer_idx.1 = util.global.load @_global_k_caches.layer_idx.1 : !stream.resource<variable>
      %c0 = arith.constant 0 : index
      %c8192 = arith.constant 8192 : index
      %c33546240 = arith.constant 33546240 : index
      %0 = arith.muli %_global_seq_step.global, %c8192 : index
      %1 = affine.apply #map()[%_global_seq_step.global]
      %2 = arith.muli %1, %c8192 : index
      %3 = arith.addi %0, %c8192 : index
      %4 = stream.timepoint.join max(%_global_k_caches.layer_idx.0__timepoint, %_global_k_caches.layer_idx.0__timepoint_0, %_global_k_caches.layer_idx.0__timepoint_1) => !stream.timepoint
      %results:3, %result_timepoint = stream.async.execute await(%4) => with(%_global_k_caches.layer_idx.0 as %arg1: !stream.resource<variable>{%c33546240}, %_global_v_caches.layer_idx.0 as %arg2: !stream.resource<variable>{%c33546240}, %_global_k_caches.layer_idx.1 as %arg3: !stream.resource<variable>{%c33546240}) -> (%_global_k_caches.layer_idx.0{%c33546240}, %_global_v_caches.layer_idx.0{%c33546240}, %_global_k_caches.layer_idx.1{%c33546240}) {
        %6:3 = stream.async.concurrent with(%arg1 as %arg4: !stream.resource<variable>{%c33546240}, %arg2 as %arg5: !stream.resource<variable>{%c33546240}, %arg3 as %arg6: !stream.resource<variable>{%c33546240}) -> (!stream.resource<variable>{%0}, !stream.resource<variable>{%0}, !stream.resource<variable>{%0}) {
          %9 = stream.async.slice %arg4[%c0 to %0] : !stream.resource<variable>{%c33546240} -> !stream.resource<variable>{%0}
          %10 = stream.async.slice %arg5[%c0 to %0] : !stream.resource<variable>{%c33546240} -> !stream.resource<variable>{%0}
          %11 = stream.async.slice %arg6[%c0 to %0] : !stream.resource<variable>{%c33546240} -> !stream.resource<variable>{%0}
          stream.yield %9, %10, %11 : !stream.resource<variable>{%0}, !stream.resource<variable>{%0}, !stream.resource<variable>{%0}
        }
        %7:3 = stream.async.concurrent with(%6#0 as %arg4: !stream.resource<variable>{%0}, %6#1 as %arg5: !stream.resource<variable>{%0}, %6#2 as %arg6: !stream.resource<variable>{%0}) -> (!stream.resource<transient>{%c8192}, !stream.resource<transient>{%c8192}, !stream.resource<transient>{%c8192}) {
          %9 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1](%_global_seq_step.global, %arg4[%c0 to %0 for %0], %1) : (index, !stream.resource<variable>{%0}, index) -> !stream.resource<transient>{%c8192}
          %10 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1](%_global_seq_step.global, %arg5[%c0 to %0 for %0], %1) : (index, !stream.resource<variable>{%0}, index) -> !stream.resource<transient>{%c8192}
          %11 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1](%_global_seq_step.global, %arg6[%c0 to %0 for %0], %1) : (index, !stream.resource<variable>{%0}, index) -> !stream.resource<transient>{%c8192}
          stream.yield %9, %10, %11 : !stream.resource<transient>{%c8192}, !stream.resource<transient>{%c8192}, !stream.resource<transient>{%c8192}
        }
        %8:3 = stream.async.concurrent with(%arg1 as %arg4: !stream.resource<variable>{%c33546240}, %7#0 as %arg5: !stream.resource<transient>{%c8192}, %arg2 as %arg6: !stream.resource<variable>{%c33546240}, %7#1 as %arg7: !stream.resource<transient>{%c8192}, %arg3 as %arg8: !stream.resource<variable>{%c33546240}, %7#2 as %arg9: !stream.resource<transient>{%c8192}) -> (%arg1{%c33546240}, %arg2{%c33546240}, %arg3{%c33546240}) {
          %9 = stream.async.update %arg5, %arg4[%0 to %3] : !stream.resource<transient>{%c8192} -> %arg4 as !stream.resource<variable>{%c33546240}
          %10 = stream.async.update %arg7, %arg6[%0 to %3] : !stream.resource<transient>{%c8192} -> %arg6 as !stream.resource<variable>{%c33546240}
          %11 = stream.async.update %arg9, %arg8[%0 to %3] : !stream.resource<transient>{%c8192} -> %arg8 as !stream.resource<variable>{%c33546240}
          stream.yield %9, %10, %11 : !stream.resource<variable>{%c33546240}, !stream.resource<variable>{%c33546240}, !stream.resource<variable>{%c33546240}
        }
        stream.yield %8#0, %8#1, %8#2 : !stream.resource<variable>{%c33546240}, !stream.resource<variable>{%c33546240}, !stream.resource<variable>{%c33546240}
      } => !stream.timepoint
      %5:3 = stream.timepoint.await %result_timepoint => %results#0, %results#1, %results#2 : !stream.resource<variable>{%c33546240}, !stream.resource<variable>{%c33546240}, !stream.resource<variable>{%c33546240}
      util.global.store %results#0, @_global_k_caches.layer_idx.0 : !stream.resource<variable>
      util.global.store %results#1, @_global_v_caches.layer_idx.0 : !stream.resource<variable>
      util.global.store %results#2, @_global_k_caches.layer_idx.1 : !stream.resource<variable>
      util.return
    }
  }
}

(gist link) Ideally, once we run ElideAsyncCopies with the new analysis, all of the async.slice ops will be replaced with subviews.

Once we have this small repro working, we should test it on the model-level IR:

  1. linalg level llama IR
  2. Llama IR before allocation scheduling

To reproduce the async slice generation, or to test whether the full model works once we have the pass, run:

/path/to/install/bin/iree-compile /path/to/flow_llama2_7b.mlir --iree-llvmcpu-target-cpu-features=host --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vulkan-target-triple=rdna3-unknown-linux --iree-rocm-target-chip=gfx1100 --iree-input-type=torch --iree-hal-target-backends=rocm --iree-rocm-link-bc=true --verify=true --iree-stream-resource-max-allocation-size=4294967296 --compile-to=hal --iree-hal-dump-executable-sources-to=llama_7b_shark2_dispatches --iree-rocm-enable-ukernels=argmax -o /dev/null --mlir-print-ir-before=iree-stream-schedule-allocation 2> before_alloca_llama2_7b_splitkv_from_flow.mlir

As noted before, most of these quirks happen in the run_forward function, so be sure to look there when hunting for them. :)

What component(s) does this issue relate to?

MLIR

Additional context

No response

raikonenfnu, Mar 04 '24 06:03

thanks for the repro! it's much easier to see what we need to do before execution is scheduled and things are nested, since ElideAsyncCopies runs very early in the pipeline.

benvanik, Mar 04 '24 15:03

your "linalg level llama IR" seems to be torch, and running it through iree-compile doesn't seem to convert out of torch? do you have compile commands that work with that?

c:/Users/Ben/Downloads/stripped_Llama_2_7b_hf_splitvk.mlir:807:13: error: failed to legalize operation 'torch.constant.int'
    %int1 = torch.constant.int 1
            ^
c:/Users/Ben/Downloads/stripped_Llama_2_7b_hf_splitvk.mlir:807:13: note: see current operation: %0 = "torch.constant.int"() <{value = 1 : i64}> : () -> !torch.int
c:/Users/Ben/Downloads/stripped_Llama_2_7b_hf_splitvk.mlir:11227:13: error: failed to legalize operation 'torch.constant.int'
    %int2 = torch.constant.int 2
            ^
c:/Users/Ben/Downloads/stripped_Llama_2_7b_hf_splitvk.mlir:11227:13: note: see current operation: %0 = "torch.constant.int"() <{value = 2 : i64}> : () -> !torch.int

benvanik, Mar 04 '24 15:03

actually, I don't care: can you just post the output of iree-compile --compile-to=flow? that's a better starting point before dealing with stream passes (the before-allocation IR you have is too late in the pipeline)

benvanik, Mar 04 '24 15:03

> actually, I don't care: can you just post the output of iree-compile --compile-to=flow? that's a better starting point before dealing with stream passes (the before-allocation IR you have is too late in the pipeline)

@benvanik Sorry about that, I have updated the link; please try that one out. :)

raikonenfnu, Mar 04 '24 17:03

which link?

benvanik, Mar 04 '24 17:03

Ah, I can see how that's confusing. I updated the "linalg level llama IR" link in the first post.

but here is the direct link for easier access: https://storage.googleapis.com/shark-public/stan/llama_decode_allocs/flow_llama2_7b.mlir

raikonenfnu, Mar 04 '24 17:03

neat, ElideAsyncCopies is already getting rid of the clones on each variable update:

    %869 = arith.muli %_global_seq_step.global, %c8192 : index
    %870 = stream.async.slice %63[%c0 to %869] : !stream.resource<*>{%c33546240} -> !stream.resource<*>{%869}
    // usage of %870
    %1819 = ...
    %1885 = stream.async.clone %63 : !stream.resource<*>{%c33546240} -> !stream.resource<*>{%c33546240}
    %1886 = stream.async.update %1819, %1885[%869 to %1884] : !stream.resource<*>{%c8192} -> %1885 as !stream.resource<*>{%c33546240}

->

    %869 = arith.muli %_global_seq_step.global, %c8192 : index
    %870 = stream.async.slice %63[%c0 to %869] : !stream.resource<*>{%c33546240} -> !stream.resource<*>{%869}
    // ....
    %1883 = arith.addi %869, %c8192 : index
    %1884 = stream.async.update %1818, %63[%869 to %1883] : !stream.resource<*>{%c8192} -> %63 as !stream.resource<*>{%c33546240}

so I should be able to get rid of the slice in this case (it's only read until the update).
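The "only reads until the update" condition is an ordering argument: if no overlapping write to the source happens between the slice and its last use, the slice can alias the source even without proving the ranges disjoint. A minimal sketch over a straight-line op list; the `Op` record and `slice_is_safe_alias` are hypothetical simplifications (real IR needs use-def walks and symbolic ranges), and the op list mirrors the shape of the IR after clone elision above with a made-up step of 5.

```python
# Sketch: a slice may alias its source if no overlapping write to the source
# occurs between the slice and the slice's last use. Hypothetical model only.
from dataclasses import dataclass, field
from typing import List, Tuple

Range = Tuple[int, int]  # half-open byte range [start, end)


@dataclass
class Op:
    name: str
    reads: List[str] = field(default_factory=list)
    writes: List[Tuple[str, Range]] = field(default_factory=list)  # in-place (value, range)


def overlaps(a: Range, b: Range) -> bool:
    return a[0] < b[1] and b[0] < a[1]


def slice_is_safe_alias(ops: List[Op], slice_idx: int, slice_value: str,
                        source: str, slice_range: Range) -> bool:
    uses = [i for i in range(slice_idx + 1, len(ops)) if slice_value in ops[i].reads]
    if not uses:
        return True  # never read, so an alias cannot observe anything
    last_use = uses[-1]
    # Conservatively include the last-use op itself: it might write the source in place.
    for op in ops[slice_idx + 1:last_use + 1]:
        for value, rng in op.writes:
            if value == source and overlaps(rng, slice_range):
                return False
    return True


E = 8192
step = 5
after_elision = [
    Op("%870 = slice %63[%c0 to %869]", reads=["%63"]),
    Op("dispatch reading %870", reads=["%870"]),
    Op("%1884 = update %1818 into %63[%869 to %1883]", reads=["%1818"],
       writes=[("%63", (step * E, (step + 1) * E))]),
]
print(slice_is_safe_alias(after_elision, 0, "%870", "%63", (0, step * E)))  # True

# Counter-example: an overlapping write to %63 before the slice is read.
hazard = [after_elision[0], Op("write %63[0 to 8192]", writes=[("%63", (0, E))]), after_elision[1]]
print(slice_is_safe_alias(hazard, 0, "%870", "%63", (0, step * E)))  # False
```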

emplace allocations fails to handle the updates. It should be able to, though I think it'll need my hazard analysis to do so given the partial dynamic subrange update. I think this used to work but #15261 broke it. Once we fix that we should have far fewer copies.

benvanik, Mar 04 '24 18:03

Woohoo, that was super quick, thanks Ben! :)

raikonenfnu, Mar 04 '24 18:03

ok, tweaked emplace allocations, so now the slow_memcpy dispatches write the updates directly into place:

    %1818 = arith.addi %869, %c8192 : index
    %1819 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1019](%_global_seq_step.global, %1024[%c0 to %1020 for %1020], %1019, %63[%869 to %1818 for %c8192]) : (index, !stream.resource<*>{%1020}, index, !stream.resource<*>{%c33546240}) -> %63{%c33546240}
    %1820 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1019](%_global_seq_step.global, %1027[%c0 to %1020 for %1020], %1019, %31[%869 to %1818 for %c8192]) : (index, !stream.resource<*>{%1020}, index, !stream.resource<*>{%c33546240}) -> %31{%c33546240}
    %1821 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1019](%_global_seq_step.global, %1051[%c0 to %1020 for %1020], %1019, %62[%869 to %1818 for %c8192]) : (index, !stream.resource<*>{%1020}, index, !stream.resource<*>{%c33546240}) -> %62{%c33546240}
    %1822 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1019](%_global_seq_step.global, %1054[%c0 to %1020 for %1020], %1019, %30[%869 to %1818 for %c8192]) : (index, !stream.resource<*>{%1020}, index, !stream.resource<*>{%c33546240}) -> %30{%c33546240}
    %1823 = stream.async.dispatch @run_forward_dispatch_804::@run_forward_dispatch_804_slow_memcpy[%_global_seq_step.global, %1019](%_global_seq_step.global, %1076[%c0 to %1020 for %1020], %1019, %51[%869 to %1818 for %c8192]) : (index, !stream.resource<*>{%1020}, index, !stream.resource<*>{%c33546240}) -> %51{%c33546240}

dispatch 804 is just

        %12 = flow.dispatch.tensor.load %2, offsets = [0, 0, %11, 0], sizes = [1, 32, 1, 128], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x32x?x128xf16>>{%1} -> tensor<32x128xf16>
        flow.dispatch.tensor.store %12, %0, offsets = [0, 0], sizes = [32, 128], strides = [1, 1] : tensor<32x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<32x128xf16>>
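For reference, here is what dispatch 804 computes, written as plain Python. It transcribes the index arithmetic from the repro's dispatch body (where %3 is the sequence step): the clamped index works out to the step itself, and the rest is a pure copy of one [32, 128] plane, consistent with the "slow_memcpy" name. Nested lists stand in for tensors, and the function names are made up for illustration.

```python
# Plain-Python reading of @run_forward_dispatch_804_slow_memcpy from the repro.
# Nested lists stand in for tensors; the function names are illustrative.


def dispatch_804_index(seq_step: int) -> int:
    """Transcribes %3..%11: clamp(seq_step + 1 - 1, 0, seq_step + 1)."""
    upper = seq_step + 1                 # %4 = affine.apply #map()[%3], cast to i64 as %5
    idx = upper - 1                      # %6 = arith.addi %5, -1
    idx = 0 if idx < 0 else idx          # %7/%8: clamp below at 0
    idx = upper if idx > upper else idx  # %9/%10: clamp above at %5
    return idx                           # %11


def slow_memcpy_804(src, seq_step: int):
    """src has shape [1][heads][D][dim]; returns a copy of the [heads][dim] plane at the index."""
    i = dispatch_804_index(seq_step)
    return [list(src[0][h][i]) for h in range(len(src[0]))]
```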

so if we can get rid of that dispatch, we'd be writing from the producers directly into the variables. of course, those producers are also slow_memcpy ops:

        %4 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [32, %0, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<32x?x128xf16>>{%0} -> tensor<32x?x128xf16>
        flow.dispatch.tensor.store %4, %3, offsets = [0, 0, 0, 0], sizes = [1, 32, %0, 128], strides = [1, 1, 1, 1] : tensor<32x?x128xf16> -> !flow.dispatch.tensor<readwrite:tensor<1x32x?x128xf16>>{%1}

so we really need to kill those: ideally we'd have zero "slow_memcpy" dispatches in the program, and until we do, we'll have copies. since the values are all produced in order, once we remove those our transient memory usage should drop a lot, as we'll be reading and writing each variable in turn and only need enough transient memory to keep the working set live.

benvanik, Mar 04 '24 18:03

> so we really need to kill those: ideally we'd have zero "slow_memcpy" dispatches in the program, and until we do, we'll have copies.

Yep! I'm currently looking into what we can do on the model side and the fusion side to get rid of the remaining dispatches that generate or require these large allocations/copies.

raikonenfnu, Mar 04 '24 18:03