Remaining work for enabling direct gather to LDS on MI-350
- [x] :hourglass: In Review: Define an `iree_{codegen,gpu}.coalesced_gather_dma` op. This op is meant to go inside the `in_parallel` region of an `scf.forall` or the like. It takes as input a tensor to gather from, a thread-level tile of indices to gather at, and a subgroup-level output to gather into (and returns no results - that subgroup-level tile is a shared out)
- [x] Extend `ParallelCombiningOpInterface` to support multiple kinds of terminators.
- [x] Audit the code for places that assume that `parallel_insert_slice` is the only possible terminator for a `forall` by doing a `cast<>` - change those to `dyn_cast<>` + failing if there's no match, unless there's obviously a better approach here
- [x] Implement tiling to threads for `iree_linalg_ext.gather` to use the coalesced DMA gather defined above
- [x] Implement a lowering from this coalesced DMA op to `amdgpu.gather_to_lds` on memrefs, including a fallback path for when the gather is not applicable
- [ ] (In order to allow the above, we may need to ensure vectorization works for computing the gather index. This may involve teaching `affine.[de]linearize_index` to accept vector inputs and return vector outputs; see the sketch after this list)
- [x] Create e2e tests to ensure the above pieces work together within the framework.
- [x] Ensure copies that we want to lower to DMAs (per lowering config, currently) also run down this flow
- [ ] Audit the previous flow for generating DMAs to see how much of it is dead code
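For the vectorization item above, a minimal sketch of what a vector-capable `affine.delinearize_index` could look like, assuming a purely elementwise extension of the existing scalar op (the vector form is exactly the proposed change and does not exist upstream today; shapes and basis values are illustrative):

```mlir
// Existing scalar form: split one flat id into (m, k) coordinates.
%m_id, %k_id = affine.delinearize_index %flat_id into (64, 16) : index, index

// Hypothetical vector form: delinearize a whole vector of flat ids at once,
// elementwise, so the gather-index computation can be vectorized.
%m_ids, %k_ids = affine.delinearize_index %flat_ids into (64, 16)
    : vector<4xindex>, vector<4xindex>
```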
Side note: if all this becomes easier when `amdgpu.gather_to_lds` can have tensor semantics, that's fine by me (the bufferization would just imply memory spaces or check for them)
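Purely to illustrate that side note, a tensor-semantics form might read roughly as below (hypothetical syntax, not an existing form of the op); bufferization would then either assign the workgroup memory space to the destination or verify that it already has it:

```mlir
// Hypothetical tensor-semantics spelling of amdgpu.gather_to_lds.
// Bufferization would imply (or check) that %dest lands in workgroup memory.
%filled = amdgpu.gather_to_lds %src[%i, %j], %dest[%k, %l]
    : vector<8xf16>, tensor<1024x128xf16>, tensor<64x128xf16>
```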
Side side note: the design doc is linked below.

@qedawkins's sketch of the setup:
```mlir
// Block level
%0 = iree_linalg_ext.gather
    input = %input : tensor<M x K>
    indices = %indices : tensor<mTile x kTile>
    output = %dest : tensor<mTile x kTile>

// Subgroup level
scf.forall (%flat_subgroup_id) {
  %m_id, %k_id = affine.delinearize_index %flat_subgroup_id into (mTile, kTile)
  %indices_slice = tensor.extract_slice %indices [%m_id, %k_id] [mTile, kTile]
  %dest_slice = tensor.extract_slice %dest [%m_id, %k_id] [mTile, kTile]
  %0 = iree_linalg_ext.gather
      input = %input : tensor<M x K>
      indices = %indices_slice : tensor<mTile x kTile>
      output = %dest_slice : tensor<mTile x kTile>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 into %dest
  }
}

// Lane level
scf.forall (%flat_subgroup_id) {
  %m_id, %k_id = affine.delinearize_index %flat_subgroup_id into (mTile, kTile)
  %indices_slice = tensor.extract_slice %indices [%m_id, %k_id] [mTile, kTile]
  %dest_slice = tensor.extract_slice %dest [%m_id, %k_id] [mTile, kTile]
  %0 = scf.forall (%flat_thread_id) shared_outs(%shared_dest_slice = %dest_slice) {
    %m_tid, %k_tid = affine.delinearize_index %flat_thread_id into (mTile, kTile)
    %indices_thread_slice = tensor.extract_slice %indices_slice [%m_tid, %k_tid] [m, k]
    scf.forall.in_parallel {
      // The coalesced DMA op sits directly in the in_parallel region and
      // gathers into the shared out.
      iree_gpu.coalesced_gather_dma
          input = %input : tensor<M x K>
          indices = %indices_thread_slice : tensor<m x k>
          output = %shared_dest_slice : tensor<mTile x kTile>
    }
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0 into %dest
  }
}
```
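After bufferization, the lane-level `iree_gpu.coalesced_gather_dma` is meant to lower to `amdgpu.gather_to_lds` on memrefs, with each lane contributing one 32- or 128-bit chunk, and a plain load/store copy as the fallback when that constraint can't be met. A rough sketch of the target form (the exact `amdgpu.gather_to_lds` assembly may differ, and `%indices_buf`, `%tid`, `%wg_base`, and the shapes are made-up names standing in for the per-lane index computation):

```mlir
// Each lane loads the row index for its chunk from the thread-level indices
// tile, then issues a direct global -> LDS transfer.
%lane_row = memref.load %indices_buf[%tid] : memref<64xindex>
// One 128-bit chunk (8 x f16) per lane, deposited straight into the
// subgroup's LDS tile without going through registers.
amdgpu.gather_to_lds %global[%lane_row, %c0], %lds[%wg_base, %c0]
    : vector<8xf16>, memref<1024x128xf16>,
      memref<64x128xf16, #gpu.address_space<workgroup>>
```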
Design doc: https://gist.github.com/lialan/97900b4e3a54ec49b2a75bd27fa0bfe4
Notes from meeting:
- If the gather tensor actually spans into the inner dimensions of the source (that is, you're not gathering tensor slices), some non-trivial analysis will be needed to prove that we can turn this into a `gather_to_lds` (since said gather only deals in 32- or 128-bit chunks)
- This is best avoided by not doing an inner-dimension gather (at which point, post-tiling, you have a weird species of copy), because then you know you have said contiguous slices
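To make the contiguity point concrete, in the same pseudo-syntax as the sketch above (shapes and index layouts are schematic): an outer-dimension gather moves whole rows, so every gathered piece is a contiguous slice that decomposes into 32-/128-bit chunks, while an inner-dimension gather picks individual elements with no such guarantee.

```mlir
// Outer-dimension gather: each index picks a whole row of %input, so the data
// moved per index is a contiguous 1 x K slice - straightforward to chunk for
// gather_to_lds.
%rows = iree_linalg_ext.gather
    input = %input : tensor<M x K>
    indices = %row_indices : tensor<mTile>
    output = %row_dest : tensor<mTile x K>

// Inner-dimension gather: indices also select positions along K, so
// consecutive gathered elements are generally not adjacent in memory; this is
// the case that needs the non-trivial analysis (and is best avoided).
%elems = iree_linalg_ext.gather
    input = %input : tensor<M x K>
    indices = %elem_indices : tensor<mTile x kTile>
    output = %elem_dest : tensor<mTile x kTile>
```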