Remaining work for enabling direct gather to LDS on MI-350

Open krzysz00 opened this issue 4 months ago • 3 comments

  • [x] In Review: Define an iree_{codegen,gpu}.coalesced_gather_dma op. This op is meant to go inside the in_parallel region of an scf.forall or similar. It takes as input a tensor to gather from, a thread-level tile of indices to gather at, and a subgroup-level output to gather into. It returns no results; that subgroup-level tile is a shared_out.
    • [x] ParallelCombiningOpInterface to support multiple kinds of terminators.
  • [x] Audit the code for places that assume tensor.parallel_insert_slice is the only possible op in a forall's terminator (i.e., places that use cast<>). Change those to dyn_cast<> and fail when there's no match, unless there's an obviously better approach.
  • [x] Implement tiling to threads for iree_linalg_ext.gather to use the coalesced DMA gather defined above
  • [x] Implement a lowering from this coalesced DMA op to amdgpu.gather_to_lds on memrefs, including a fallback path for when the gather is not applicable
  • [ ] (To allow the above, we may need to make sure vectorization works when computing the gather indices. This may involve teaching affine.[de]linearize_index to accept vector inputs and return vector outputs.)
  • [x] Create e2e tests to ensure the above pieces work together within the framework.
  • [x] Ensure copies that we want to lower to DMAs (per lowering config, currently) also run down this flow
  • [ ] Audit the previous flow for generating DMAs to see how much of it is dead code
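The checklist item about vectorizing the gather-index computation hinges on what affine.[de]linearize_index compute. Below is a small Python reference model of the scalar semantics as described in MLIR's affine dialect docs (row-major mixed radix, where the outermost result is a plain floordiv with no wrap-around), plus a lane-by-lane "vector" form. The name `delinearize_vector` is hypothetical and only illustrates the elementwise behavior a vector-accepting op would need; it is not an MLIR API.

```python
from typing import List, Sequence


def delinearize_index(linear: int, basis: Sequence[int]) -> List[int]:
    """Scalar model: split `linear` into row-major coordinates for `basis`.
    The last basis entry varies fastest; the outermost coordinate is a plain
    floordiv (not reduced modulo basis[0])."""
    coords: List[int] = []
    for size in reversed(basis[1:]):
        coords.append(linear % size)
        linear //= size
    coords.append(linear)  # outermost: floordiv only, no wrap-around
    return coords[::-1]


def linearize_index(coords: Sequence[int], basis: Sequence[int]) -> int:
    """Inverse of delinearize_index for in-range coordinates."""
    linear = 0
    for c, size in zip(coords, basis):
        linear = linear * size + c
    return linear


def delinearize_vector(linears: Sequence[int], basis: Sequence[int]) -> List[List[int]]:
    """Hypothetical vector form: one output 'vector' per result dimension,
    where lane i holds the coordinates of linears[i]."""
    per_lane = [delinearize_index(x, basis) for x in linears]
    return [list(dim) for dim in zip(*per_lane)]
```

Because each lane is computed independently, the vector form is just the scalar op applied elementwise, which is why extending the ops to vector operands is plausible without changing their semantics.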

Side note: if all this becomes easier when amdgpu.gather_to_lds has tensor semantics, that's fine by me (bufferization would just imply the memory spaces or check for them).

Side side note: design doc

krzysz00 · Aug 26 '25 21:08

@qedawkins's sketch of the setup:

// Block level: one gather produces the whole mTile x kTile tile
%0 = iree_linalg_ext.gather
  input = %input : tensor<M x K>
  indices = %indices : tensor<mTile x kTile>
  output = %dest : tensor<mTile x kTile>

// Subgroup level
%result = scf.forall (%flat_subgroup_id) shared_outs(%shared_dest = %dest) {
  %m_id, %k_id = affine.delinearize_index %flat_subgroup_id into (mTile, kTile)
  // Slice the indices (not the source); every subgroup reads the full %input
  %indices_slice = tensor.extract_slice %indices [%m_id, %k_id] [mTile, kTile]
  %dest_slice = tensor.extract_slice %shared_dest [%m_id, %k_id] [mTile, kTile]
  %0 = iree_linalg_ext.gather
    input = %input : tensor<M x K>
    indices = %indices_slice : tensor<mTile x kTile>
    output = %dest_slice : tensor<mTile x kTile>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 into %shared_dest [%m_id, %k_id] [mTile, kTile]
  }
}

// Lane level
%result = scf.forall (%flat_subgroup_id) shared_outs(%shared_dest = %dest) {
  %m_id, %k_id = affine.delinearize_index %flat_subgroup_id into (mTile, kTile)
  %indices_slice = tensor.extract_slice %indices [%m_id, %k_id] [mTile, kTile]
  %dest_slice = tensor.extract_slice %shared_dest [%m_id, %k_id] [mTile, kTile]
  %0 = scf.forall (%flat_thread_id) shared_outs(%shared_dest_slice = %dest_slice) {
    // Delinearize the thread id (not the subgroup id) into fresh names so
    // the subgroup offsets above are not redefined
    %tm_id, %tk_id = affine.delinearize_index %flat_thread_id into (mTile, kTile)
    %indices_thread_slice = tensor.extract_slice %indices_slice [%tm_id, %tk_id] [m, k]
    scf.forall.in_parallel {
      // No results: each thread's indices feed one subgroup-wide DMA that
      // writes the subgroup-level shared_out in place
      iree_gpu.coalesced_gather_dma
        input = %input : tensor<M x K>
        indices = %indices_thread_slice : tensor<m x k>
        output = %shared_dest_slice : tensor<mTile x kTile>
    }
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 into %shared_dest [%m_id, %k_id] [mTile, kTile]
  }
}
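To sanity-check that the two-level distribution in the sketch computes the same thing as the block-level gather, here is a sequential Python reference model. It assumes each index selects a whole row of the source (the slice-gather case), and that rows divide evenly among subgroups and threads; the function names and the row-gather simplification are illustrative assumptions, not IREE APIs. The real lowering runs the loops in parallel and issues one coalesced DMA per subgroup.

```python
from typing import List

Matrix = List[List[float]]


def gather_rows(source: Matrix, indices: List[int]) -> Matrix:
    """Block-level reference: output row i is source[indices[i]]."""
    return [list(source[idx]) for idx in indices]


def tiled_gather_rows(source: Matrix, indices: List[int],
                      num_subgroups: int, threads_per_subgroup: int) -> Matrix:
    """Models the subgroup-level forall (shared_out per subgroup) and the
    lane-level forall (each thread supplies a slice of the indices)."""
    m_tile = len(indices)
    assert m_tile % num_subgroups == 0, "rows must split evenly across subgroups"
    rows_per_subgroup = m_tile // num_subgroups
    assert rows_per_subgroup % threads_per_subgroup == 0, "rows must split evenly across threads"
    rows_per_thread = rows_per_subgroup // threads_per_subgroup

    dest: Matrix = [[] for _ in range(m_tile)]
    for sg in range(num_subgroups):                      # outer scf.forall
        sg_base = sg * rows_per_subgroup
        sg_indices = indices[sg_base:sg_base + rows_per_subgroup]
        sg_dest: Matrix = [[] for _ in range(rows_per_subgroup)]  # shared_out tile
        for t in range(threads_per_subgroup):            # inner scf.forall
            t_base = t * rows_per_thread
            for r in range(rows_per_thread):             # this thread's index slice
                sg_dest[t_base + r] = list(source[sg_indices[t_base + r]])
        # parallel_insert_slice of the subgroup tile into the block tile
        dest[sg_base:sg_base + rows_per_subgroup] = sg_dest
    return dest
```

For example, with an 8-row index tile, `tiled_gather_rows(source, indices, 2, 4)` should equal `gather_rows(source, indices)`; each thread then writes exactly one destination row.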

krzysz00 · Aug 26 '25 21:08

Design doc: https://gist.github.com/lialan/97900b4e3a54ec49b2a75bd27fa0bfe4

lialan · Aug 27 '25 16:08

Notes from meeting:

  • If the gather actually reaches into the inner dimensions of the source (that is, you're not gathering whole tensor slices), some non-trivial analysis will be needed to prove that we can turn it into a gather_to_lds, since that gather only deals in 32- or 128-bit chunks.
  • This is best avoided by not gathering along the inner dimension: then you know you are moving contiguous slices, and post-tiling you have a weird species of copy.
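A rough Python sketch of the distinction drawn in these notes: gathering whole rows of an M x K source always touches a unit-stride run of elements, whereas picking arbitrary inner-dimension columns generally does not, so it would need a proof of contiguity first. The chunk check assumes, per the notes, that one gather_to_lds transfer moves 32 or 128 bits, and that the source base is suitably aligned (alignment is not otherwise modeled). All names here are hypothetical helpers for illustration.

```python
from typing import List, Optional, Sequence

# Assumption from the notes above: a gather_to_lds transfer moves 32 or 128 bits.
SUPPORTED_CHUNK_BITS = (128, 32)


def pick_chunk_bits(row_elems: int, elem_bits: int) -> Optional[int]:
    """Widest supported chunk that evenly tiles one contiguous gathered row,
    or None if none fits (i.e., the fallback path would be needed). If the
    row size is a multiple of the chunk, every row start is chunk-aligned
    relative to an aligned base."""
    row_bits = row_elems * elem_bits
    for chunk in SUPPORTED_CHUNK_BITS:
        if row_bits % chunk == 0:
            return chunk
    return None


def rows_are_contiguous(flat_offsets: Sequence[int]) -> bool:
    """True if the flat element offsets form a single unit-stride run."""
    return all(b - a == 1 for a, b in zip(flat_offsets, flat_offsets[1:]))


def outer_dim_gather_offsets(row: int, k: int) -> List[int]:
    """Flat offsets touched when gathering full row `row` of an M x K source."""
    return [row * k + c for c in range(k)]


def inner_dim_gather_offsets(row: int, cols: Sequence[int], k: int) -> List[int]:
    """Flat offsets touched when the indices pick arbitrary columns of one row."""
    return [row * k + c for c in cols]
```

Outer-dimension gathers pass `rows_are_contiguous` by construction, so only the chunk-size check remains; inner-dimension gathers pass only when the chosen columns happen to be consecutive, which is the kind of fact the analysis mentioned above would have to establish.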

krzysz00 · Aug 28 '25 15:08