[BUG] Weird Behavior
Describe the bug
The flux benchmark reports a negative communication time (comm), which is unexpected.
To Reproduce
torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 --rdzv_endpoint=127.0.0.1:23456 test/test_ag_kernel.py 1024 57344 8192 --dtype=bfloat16 --iters=100
Behavior
SOL time for GEMM(M=1024,N=57344,K=8192,TP=8): 0.122ms
torch #0: total 264.296 us, gemm 188.893 us, comm 75.403 us
torch #1: total 264.395 us, gemm 186.656 us, comm 77.739 us
torch #2: total 264.337 us, gemm 183.318 us, comm 81.019 us
torch #3: total 264.520 us, gemm 183.794 us, comm 80.726 us
torch #4: total 263.984 us, gemm 184.681 us, comm 79.303 us
torch #5: total 264.138 us, gemm 184.739 us, comm 79.399 us
torch #6: total 264.050 us, gemm 186.116 us, comm 77.933 us
torch #7: total 264.118 us, gemm 183.485 us, comm 80.633 us
flux #0: total 250.066 us, gemm 384.129 us, comm -134.063 us, gemm_only 191.983 us
flux #1: total 250.059 us, gemm 404.430 us, comm -154.371 us, gemm_only 192.030 us
flux #2: total 250.074 us, gemm 421.114 us, comm -171.040 us, gemm_only 192.004 us
flux #3: total 250.023 us, gemm 379.205 us, comm -129.183 us, gemm_only 192.094 us
flux #4: total 250.127 us, gemm 417.929 us, comm -167.803 us, gemm_only 191.965 us
flux #5: total 250.164 us, gemm 387.587 us, comm -137.423 us, gemm_only 192.135 us
flux #6: total 249.988 us, gemm 400.141 us, comm -150.154 us, gemm_only 192.068 us
flux #7: total 250.242 us, gemm 386.544 us, comm -136.302 us, gemm_only 192.158 us
Environment
8x H100 GPUs
flux #0: total 250.066 us, gemm 384.129 us, comm -134.063 us, gemm_only 191.983 us
- total is measured with AG+GEMM
- gemm is measured with a separate GEMM-only implementation
- comm is total - gemm
- gemm_only is measured with the AG+GEMM implementation but without the AG step
Maybe this shape is tuned for AG+GEMM but not for the separate GEMM-only implementation, so gemm comes out larger than total and comm goes negative.
That is confusing. Maybe we should redefine these fields to make them clearer. @wenlei-bao
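As a quick check of the arithmetic, here is a minimal Python sketch using the flux #0 numbers from the log above. The variable names are illustrative and not taken from the flux test script. The last line subtracts gemm_only instead of gemm, which is only one possible alternative baseline, not something the benchmark currently reports.

```python
# Values copied from the "flux #0" log line (microseconds).
total_us = 250.066      # AG+GEMM
gemm_us = 384.129       # separate GEMM-only implementation
gemm_only_us = 191.983  # AG+GEMM implementation without the AG step

# Current definition: comm = total - gemm.
comm_us = total_us - gemm_us
print(f"comm (total - gemm): {comm_us:.3f} us")

# Assumption for illustration only: use gemm_only as the baseline instead.
comm_vs_gemm_only_us = total_us - gemm_only_us
print(f"total - gemm_only:   {comm_vs_gemm_only_us:.3f} us")
```

With these inputs the first difference reproduces the reported -134.063 us, while the second is positive because gemm_only is smaller than total.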
@houqi Agreed. We could probably delete that field and bring it back with a new definition, maybe comm = max(0, total - gemm)?
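A minimal sketch of that proposed redefinition, assuming all timings are in microseconds. The helper name `clamped_comm_us` is hypothetical and does not exist in flux.

```python
def clamped_comm_us(total_us: float, gemm_us: float) -> float:
    """Proposed redefinition: comm = max(0, total - gemm).

    Hypothetical helper for illustration; not part of the flux test script.
    """
    return max(0.0, total_us - gemm_us)


# flux #0 from the log: gemm exceeds total, so the result is clamped to 0.
flux_comm = clamped_comm_us(250.066, 384.129)

# torch #0 from the log: total exceeds gemm, so the value is unchanged.
torch_comm = clamped_comm_us(264.296, 188.893)

print(f"flux #0 comm:  {flux_comm:.3f} us")
print(f"torch #0 comm: {torch_comm:.3f} us")
```

Note that clamping hides the sign but not the cause: a value of 0 would still mean the separate GEMM-only measurement was slower than the fused AG+GEMM run.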