TransformerEngine
[PYTORCH::FP8] FP8 significantly slow down when scaling up to 1000+ GPUs
Small LLMs trained using FP8 on 32 GPUs can achieve a 20~30% speedup compared with bf16. However, scaling up to 1000+ GPUs yields less than a 5% speedup (TP2 PP4 VP4).
Any suggestions on how to deal with this problem? Are there any plans to support FP8 all2all?
Thanks.
Hi @Ageliss, could you share more details on your training setup? The most probable reason for the lower speedup is that other bottlenecks (most likely communication, since the issue occurs at large scale) start to dominate your execution time.
Since you are asking about all2all performance - are you interested in MoE-type models? Generally speaking, as long as the scaling factors are kept in sync, FP8 all2all can be successfully emulated by int8 all2all (all2all involves no reduction or other mathematical operations, it just moves bytes).
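For example, something along these lines (a minimal sketch, not TE's actual API; it assumes the FP8 scaling factor is already identical on every rank and a PyTorch build that has a float8 dtype):

```python
# Minimal sketch: emulate an FP8 all2all by moving the FP8 payload as raw
# uint8 bytes. all2all performs no arithmetic, so reinterpreting the bytes
# is safe as long as every rank uses the same scaling factor.
import torch
import torch.distributed as dist

def fp8_all_to_all(x_fp8_bytes: torch.Tensor) -> torch.Tensor:
    """x_fp8_bytes: FP8 data reinterpreted as torch.uint8 (same element size)."""
    out = torch.empty_like(x_fp8_bytes)
    dist.all_to_all_single(out, x_fp8_bytes)  # NCCL just moves bytes
    return out

# Hypothetical usage with a shared scale (names are illustrative):
# x_fp8 = (x / shared_scale).to(torch.float8_e4m3fn)
# recv  = fp8_all_to_all(x_fp8.view(torch.uint8)).view(torch.float8_e4m3fn)
```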
Thanks for your response. Yes, we used MoE models.
Training setup: over the last few days I tried several setups and found that reducing TP2 to TP1 recovers most of the speedup, i.e., from 5% to 15%. We checked the timeline and found something strange:
- Using TP2 PP4, gather + all2all is slower than bf16 (-30%)
- Using TP1 PP4, gather + all2all is faster than bf16 (+17%)
Also, we tried FP8 all2all, but the overhead is not small (we first compute the row max and divide by it before cast_to_fp8; after the all2all + fp8_gemm, we need to multiply the row max back). With fewer than 128 GPUs, we did not observe any acceleration.
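Roughly what we do is the following (simplified sketch, not our exact code; `fp8_gemm` stands in for the real FP8 GEMM call):

```python
# Illustrative per-row scaling around the FP8 all2all.
import torch
import torch.distributed as dist

def rowwise_fp8_all2all(x: torch.Tensor):
    # Extra kernel 1: per-row amax.
    row_max = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    # Extra kernel 2: divide by row max, then cast to FP8.
    x_fp8 = (x / row_max).to(torch.float8_e4m3fn)

    # Exchange the FP8 payload as raw bytes, plus a second (small) all2all
    # for the per-row scales so the receiver can undo the scaling.
    x_recv = torch.empty_like(x_fp8)
    dist.all_to_all_single(x_recv.view(torch.uint8), x_fp8.view(torch.uint8))
    scale_recv = torch.empty_like(row_max)
    dist.all_to_all_single(scale_recv, row_max)
    return x_recv, scale_recv

# After the all2all: y = fp8_gemm(x_recv, w_fp8)   # placeholder
#                    y = y * scale_recv            # extra kernel 3: rescale
```

The extra kernels (amax, divide, cast, rescale) are each small, so at smaller scales they seem to cancel out the bandwidth savings.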
So, the question is: how does TP2 affect FP8 and make it slower? Thanks again for your last response.
Hi @Ageliss, could you clarify whether the -30% and +17% numbers you provided are for just the gather+all2all operation or for the entire network?
In general, TP2 will make the per-GPU work smaller, which could affect a few things:
- the communication could become more latency-bound as opposed to bandwidth-bound, which would limit the speedup from using smaller datatype
- the GPU workload could become too small to cover the CPU overheads of Python execution and kernel launches. That would make the speedup coming from FP8 irrelevant, since the GPU would no longer be the limiting factor. What is more, FP8 needs to launch more kernels than BF16, since there are additional casting operations, so this can disproportionately affect FP8 execution and result in FP8 being slower than BF16.
In general it is best to profile the execution to see where the slowdown is coming from. The CPU overhead would show up on the profile either as portions of time when the GPU is idle (or only the NCCL communication kernel is running), or as the host-side CUDA kernel launches of the computation kernels happening just before the actual kernels.
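For example, a rough torch.profiler setup along these lines (illustrative only; `train_step` is a placeholder for your actual training step) makes both symptoms visible in the trace:

```python
# Minimal profiling sketch: capture a few steps and inspect GPU idle time and
# host-side kernel-launch overhead in the exported trace (view it in
# chrome://tracing or Nsight Systems).
import torch
from torch.profiler import profile, schedule, ProfilerActivity

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=2, active=3),
    on_trace_ready=lambda p: p.export_chrome_trace("fp8_step_trace.json"),
)

with prof:
    for step in range(6):
        train_step()   # placeholder: one forward/backward/optimizer step
        prof.step()    # advance the profiler schedule

# Gaps between CUDA kernels in the trace indicate GPU idle time (waiting on
# NCCL or the CPU), while long runs of cudaLaunchKernel on the CPU track point
# to launch-bound execution, where FP8's extra cast kernels hurt the most.
```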
Hi @ptrendx, the -30% and +17% numbers are for the end-to-end training speed.
We checked the timeline and found that GEMMs do not take up much of a step (less than 20%); most of the time (60%+) the GPU is idle, perhaps waiting for NCCL or something else...
Thanks again.