[QUESTION] Why use Remote TMA Load for gemm_rs sm90 implementations?
It seems to me that fetching from remote memory has obvious drawbacks: (1) a threadblock has to poll a flag before it can actually read the data (a minimal sketch of the polling pattern I mean is below); (2) because only one 128x32 stage in shared memory can be used, read performance will be low. And the advantages of using TMA Reduce to the remote Hopper GPUs seem just as obvious (at least to me): (1) it removes the latency of polling the remote flag; (2) since TMA store (and reduce) supports an early-done mechanism, a single stage is sufficient to keep writing data out continuously.
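For reference, this is roughly the polling pattern I am describing in drawback (1) — a minimal sketch with hypothetical names (`pull_remote_tile`, `remote_flag`, etc.), not the actual Flux code:

```cuda
#include <cuda/atomic>

// Hypothetical consumer-side kernel: before a block can pull a tile from a
// peer GPU, it must spin on a flag set by the producer rank. Names and the
// flat buffer layout are illustrative assumptions only.
__global__ void pull_remote_tile(int* remote_flag,         // per-tile flag on the peer GPU
                                 const float* remote_tile, // tile data on the peer GPU
                                 float* local_stage,       // local staging buffer
                                 int tile_elems) {
  // System-scope acquire: orders the flag load before the data loads and makes
  // the producer's writes visible across the NVLink peer mapping.
  cuda::atomic_ref<int, cuda::thread_scope_system> flag(*remote_flag);
  if (threadIdx.x == 0) {
    while (flag.load(cuda::memory_order_acquire) == 0) { /* poll */ }
  }
  __syncthreads();

  // Only now is it safe to read the remote tile; the polling latency sits
  // directly on the critical path of the load.
  for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
    local_stage[i] = remote_tile[i];
  }
}
```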
So I believe you must have your reasons for adopting the current implementation, and my questions are:
(1) Why not directly reduce the D matrix to the remote GPUs (like your sm80 implementation)? If deterministic floating-point addition is the concern, a store rather than a reduce could be used, just as the FuseReduction=False path already does (see the sketch after question (2) below).
(2) Why must other threadblocks wait for the local tile to be reduced first? Is this necessary? The code is below:
```cpp
if constexpr (FuseReduction) {
  if (not is_local_tile_reduce) {
    // if this tile is fetched from other rank, wait for the local rank to reduce first
    Barrier::wait_lt(lock_ptr, thread_idx, flag_idx, 1);
  }
}
```
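To make question (1) concrete, here is a hypothetical sketch of the alternative I have in mind: each rank stores its partial tile into its own slot of the destination rank's scratch buffer (plain stores, no atomics), and the destination rank then reduces the slots in a fixed rank order, which keeps floating-point addition deterministic. The names and the slot layout are assumptions, not the Flux code:

```cuda
// Producer side: write this rank's partial tile into its slot on the peer.
// Slot layout assumed: peer_scratch[rank * tile_elems + i].
__global__ void store_partial_to_peer(const float* my_partial, // this rank's GEMM output tile
                                      float* peer_scratch,     // scratch buffer on the destination rank
                                      int my_rank, int tile_elems) {
  float* my_slot = peer_scratch + my_rank * tile_elems;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < tile_elems;
       i += gridDim.x * blockDim.x) {
    my_slot[i] = my_partial[i];  // plain store over the NVLink peer mapping
  }
}

// Destination side: sum the slots in rank order 0..world_size-1, so the
// floating-point addition order is the same on every run.
__global__ void reduce_slots(const float* scratch, float* out,
                             int world_size, int tile_elems) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < tile_elems;
       i += gridDim.x * blockDim.x) {
    float acc = 0.f;
    for (int r = 0; r < world_size; ++r) {
      acc += scratch[r * tile_elems + i];
    }
    out[i] = acc;
  }
}
```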
Waiting for your kind reply :)
It turned out that this implementation leads to worse performance than the no-fusion path:
```
./launch.sh test/python/gemm_rs/test_gemm_rs.py 8192 12288 8192 --dtype=float16 --iters=10
torch #0: gemm 0.557 ms, comm 1.009 ms, total 1.566 ms
flux  #0: gemm 0.576 ms, comm 1.826 ms, total 2.402 ms
```