``fp8_group`` when using FSDP and tensor parallelism
Hi,
What is the correct ``fp8_group`` when using FSDP and tensor parallelism together?
Should it include all GPUs, or only the GPUs within each tensor-parallel group?
Thanks.
The process group for FP8 amax reductions (``fp8_group``) should span both the data-parallel and tensor-parallel groups, which is the world group in your use case. This is because the activation and dgrad tensors need consistent FP8 scaling factors when they are distributed. In principle we could get away with doing FP8 amax reductions over just the tensor-parallel group in order to reduce the amount of global communication, but that makes the FP8 casts less stable and complicates checkpointing.
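For concreteness, here is a minimal sketch (not from the original thread) of how the world group could be passed as ``fp8_group`` in an FSDP + tensor-parallel setup. The group construction, ``tp_size``, recipe values, and the plain ``te.Linear`` module are illustrative assumptions, not the library's prescribed setup:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

tp_size = 2  # illustrative tensor-parallel size; world_size must be divisible by it

# Every rank must call new_group() for every group (it is a collective call).
tp_groups = [dist.new_group(ranks=list(range(start, start + tp_size)))
             for start in range(0, world_size, tp_size)]
tp_group = tp_groups[rank // tp_size]

# Amax reductions span both the data-parallel and tensor-parallel dimensions,
# i.e. the world group in this FSDP + TP setup.
fp8_group = dist.group.WORLD

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

# In practice the module would be tensor-parallel and wrapped in FSDP;
# a plain te.Linear keeps the sketch short.
model = te.Linear(1024, 1024).cuda()
inp = torch.randn(8, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group):
    out = model(inp)
```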
I see that the example in the docs uses ``fp8_group=data_parallel_group`` instead of the world group. Even though that example runs on a single GPU and is therefore technically correct, it is confusing and should be fixed.