
``fp8_group`` when using FSDP and tensor parallelism


Hi,

What is the correct ``fp8_group`` when using FSDP and tensor parallelism together? Should it span all GPUs, or just each tensor-parallel group?

Thanks.

cavdard avatar Feb 05 '24 18:02 cavdard

The process group for FP8 amax reductions (``fp8_group``) should be the combination of the data-parallel and tensor-parallel groups, which is the world group in your use case. This is because the activation and dgrad tensors need consistent FP8 scaling factors on every rank they are distributed across. In principle we could get away with doing FP8 amax reductions over just the tensor-parallel group in order to reduce the amount of global communication, but that makes the FP8 casts less stable and complicates checkpointing.
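For concreteness, here is a minimal sketch (not from this thread) of passing the world group as ``fp8_group``. It assumes a standard one-process-per-GPU setup launched with ``torchrun``, and uses a single ``te.Linear`` as a hypothetical stand-in for a model that would be wrapped with FSDP and tensor parallelism:

```python
# Minimal sketch: FP8 amax reduction over all ranks (DP x TP = world group).
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# fp8_group spans *all* ranks, so every rank ends up with identical
# FP8 scaling factors for activations and dgrads.
world_group = dist.group.WORLD

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

# Hypothetical single layer standing in for an FSDP + tensor-parallel model.
layer = te.Linear(1024, 1024)

inp = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):
    out = layer(inp)
out.sum().backward()  # amaxes are reduced over world_group during training
```

Reducing amaxes over the full world group costs a slightly larger all-reduce per iteration, but keeping scaling factors identical on every rank is what keeps the FP8 casts stable and checkpointing simple.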

I see that the example in the docs uses ``fp8_group=data_parallel_group`` instead of the world group. That happens to be technically correct there because the example runs on a single GPU, but it is confusing and should be fixed.

timmoon10 avatar Feb 29 '24 21:02 timmoon10