Add Semaphore Support for `cp.async` loads (Non-TMA Load Patterns)
This PR introduces semaphore support for non-TMA `load_async` operations by leveraging the PTX instruction `cp.async.mbarrier.arrive.noinc.shared::cta.b64`. The change simplifies producer-consumer kernels with non-standard load patterns that TMA cannot handle.
motivation
We developed sparse matmul kernels that required `cp.async` instead of TMA due to unique memory-layout requirements. Currently, producer-consumer kernels force the producer to call `cp.async.wait_all` and manually signal the semaphore (e.g. the FFTConv kernel). On a generic matmul kernel, our tests show that manually waiting with `cp.async.wait_all` plus an explicit `arrive(bar)` is over 200 TFLOPS slower than letting `cp.async` signal the semaphore automatically.
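For reference, the manual pattern being replaced looks roughly like this (a sketch, not the exact FFTConv kernel code; `arrive` and the `load_async` call stand in for the library's helpers):

```cuda
// Manual signaling (the slower pattern): the producer blocks until
// every outstanding cp.async it issued completes, then signals the
// semaphore itself.
load_async(shared_tile, global_tensor, {i, j, k, l});  // no semaphore arg
asm volatile("cp.async.wait_all;\n" ::: "memory");     // stalls the thread
arrive(bar);                                           // explicit signal
```

The stall on `cp.async.wait_all` is exactly what the new automatic signaling avoids.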
note on semaphores:
The PTX instruction `cp.async.mbarrier.arrive.noinc.shared::cta.b64` ensures that once all prior `cp.async` operations issued by the current thread complete, that thread automatically arrives at the semaphore. Until then, it can work on other tasks. For example, when `warpgroup::load_async` is called with a semaphore, the expected arrival count is 128 (32 threads per warp * 4 warps). Detailed explanations are provided in the updated library comments.
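The underlying mechanism can be sketched as a small device helper (illustrative, not the library's actual implementation; `cp_async_arrive_on` is a hypothetical name):

```cuda
#include <cstdint>

// Attach the completion of this thread's prior cp.async operations to an
// mbarrier living in shared memory. The thread does NOT block: the arrive
// is triggered asynchronously once those copies finish.
__device__ void cp_async_arrive_on(uint64_t* smem_mbar) {
    uint32_t mbar_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(smem_mbar));
    asm volatile(
        "cp.async.mbarrier.arrive.noinc.shared::cta.b64 [%0];\n"
        :: "r"(mbar_addr));
}
```

The `.noinc` variant does not increment the barrier's pending-arrival count, which is why the expected count (e.g. 128 for a warpgroup) must be set up front in `init_semaphore`.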
changes
- Non-TMA `load_async` operations can now automatically work with semaphores by accepting an optional semaphore parameter.
- Updated load strategies in 4 areas:
- Tile - warp level
- Tile - group level
- Vector - warp level
- Vector - group level
- Added tests to ensure correctness of the new operations.
usage
producer:
```cuda
__shared__ semaphore bar;
if (threadIdx.x == 0) init_semaphore(bar, 128, 0);
__syncthreads();
load_async(shared_tile, global_tensor, {i, j, k, l}, bar);
```
consumer:
```cuda
int tic = 0;
wait(bar, tic);
tic ^= 1;
```
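Putting both halves together, a double-buffered main loop might look like the following (a sketch; the buffer array, `K_TILES`, and the index tuple are illustrative, while `init_semaphore`, `load_async`, and `wait` follow the snippets above):

```cuda
__shared__ semaphore bar[2];           // one semaphore per buffer
if (threadIdx.x == 0) {
    init_semaphore(bar[0], 128, 0);    // 128 = 32 threads/warp * 4 warps
    init_semaphore(bar[1], 128, 0);
}
__syncthreads();

int tic[2] = {0, 0};                   // per-semaphore phase bits
load_async(shared_tile[0], global_tensor, {i, j, 0, l}, bar[0]);
for (int k = 0; k < K_TILES; k++) {
    int buf = k % 2;
    if (k + 1 < K_TILES)               // prefetch next tile into other buffer
        load_async(shared_tile[1 - buf], global_tensor,
                   {i, j, k + 1, l}, bar[1 - buf]);
    wait(bar[buf], tic[buf]);          // block until this buffer's loads land
    tic[buf] ^= 1;                     // semaphore phase flips on completion
    // ... compute on shared_tile[buf] ...
}
```

Because each `cp.async` arrives at the semaphore on its own, the producer side never has to stall on `cp.async.wait_all` between prefetches.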