
FSDP: How to do all-gather using FP8?

Open • vgoklani opened this issue 1 year ago • 3 comments

FSDP2 supports all-gather using FP8:

https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323

Wondering if we could do this directly using TransformerEngine instead of torchao?

Thanks!

vgoklani avatar Sep 17 '24 04:09 vgoklani

Hi @vgoklani -- TE modules can be initialized under the with te.fp8_model_init(): context to allocate their primary weights in FP8 (as te.Float8Tensors) instead of allocating at a higher precision and maintaining separate FP8 buffers for compute.
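As a minimal sketch of what that initialization looks like (untested here; assumes a CUDA device with FP8 support and a recent TE release, and the module choice is just illustrative):

```python
import torch
import transformer_engine.pytorch as te

# Allocate primary weights directly in FP8 (te.Float8Tensor) instead of
# keeping higher-precision parameters plus separate FP8 compute buffers.
with te.fp8_model_init(enabled=True):
    model = te.Linear(4096, 4096, bias=True)

# The parameter's underlying storage is torch.uint8 -- this is the data
# that FSDP2's per-parameter sharding would shard and all-gather.
print(type(model.weight), model.weight.dtype)
```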

I don't believe anyone has tried this in practice, but at least in principle, FSDP2's per-parameter sharding should work out-of-the-box with the torch.uint8 data underneath our te.Float8Tensors.

There are two things to be mindful of here:

  1. You would not use the precompute_float8_dynamic_scale_for_fsdp(model) API from the linked example because TE already does this internally. You simply need to pass the process group for amax reductions (typically global/world group) into the te.fp8_autocast() context.
  2. In the absence of native FP8 support in PyTorch, you cannot apply the optimizer step directly onto the FP8 model parameters. Consequently, te.fp8_model_init() is intended to be used with higher precision "master" copies of the model parameters in the optimizer.
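Putting both points together, a rough sketch of the training loop (untested; the recipe settings, shapes, and the master-weight bookkeeping shown here are illustrative assumptions, and copying updated master weights back into the FP8 parameters is framework-specific and omitted):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dist.init_process_group(backend="nccl")
world_group = dist.group.WORLD  # process group for amax reductions

with te.fp8_model_init(enabled=True):
    model = te.Linear(4096, 4096)

# Point 2: keep higher-precision "master" copies for the optimizer,
# since the optimizer step cannot be applied to FP8 parameters directly.
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params, lr=1e-4)

inp = torch.randn(32, 4096, device="cuda")
for step in range(10):
    # Point 1: pass the amax-reduction group to fp8_autocast; TE handles
    # scale updates internally (no precompute_float8_dynamic_scale_for_fsdp).
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(),
                         fp8_group=world_group):
        out = model(inp)
    out.float().sum().backward()

    # Copy gradients onto the master copies before stepping.
    for mp, p in zip(master_params, model.parameters()):
        mp.grad = p.grad.float() if p.grad is not None else None
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    # Requantizing the updated master weights back into the FP8
    # parameters would happen here.
```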

If you experiment with TE + FSDP2, please share your experiences. We already support PyTorch's native FSDP but this involves TE modules carrying extra FP8 buffers for the compute while FSDP communication remains in higher precision. It would be great to extend our FSDP support to te.fp8_model_init() + FSDP2.
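For anyone who wants to try that experiment, the wiring might look roughly like this (untested; `fully_shard` is FSDP2's per-parameter API and its import path has moved across PyTorch releases, so treat this as a sketch):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
# FSDP2 entry point; location varies by PyTorch version.
from torch.distributed._composable.fsdp import fully_shard

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

with te.fp8_model_init(enabled=True):
    model = te.TransformerLayer(hidden_size=4096,
                                ffn_hidden_size=16384,
                                num_attention_heads=32)

# Per-parameter sharding: in principle FSDP2 shards the torch.uint8
# data underneath te.Float8Tensor, so the all-gather moves FP8 bytes.
fully_shard(model)
```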

denera avatar Sep 17 '24 20:09 denera

Adding to this, FSDP support should just be a matter of implementing fsdp_pre_all_gather and fsdp_post_all_gather methods in Float8Tensor, at least in principle.
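Those hooks would presumably mirror what torchao's Float8Tensor does for FSDP2. A sketch of the shape they might take (signatures follow PyTorch's FSDP2 tensor-subclass extension point as I understand it; the attribute names and metadata layout are assumptions, and this is not implemented in TE):

```python
# Hypothetical FSDP2 hooks on TE's Float8Tensor (not implemented in TE).
class Float8Tensor(torch.Tensor):
    def fsdp_pre_all_gather(self, mesh):
        # Return the raw uint8 shard(s) to all-gather, plus metadata
        # (e.g. the FP8 scale) needed to rebuild the tensor afterwards.
        return (self._data,), (self._scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata,
                             param_dtype, *, out=None):
        # Reconstruct a Float8Tensor from the gathered uint8 data
        # and the scale carried through metadata.
        (data,) = all_gather_outputs
        (scale,) = metadata
        ...
```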

timmoon10 avatar Sep 18 '24 23:09 timmoon10

I'm also interested in FP8 all-gather.

zigzagcai avatar Mar 14 '25 10:03 zigzagcai