
FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2

Open vthumbe1503 opened this pull request 4 days ago • 4 comments

Description

FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2

Fixes # (issue)

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [x] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactoring

Changes

Please list the changes introduced in this PR:

  • For FSDP2, FP8/MXFP8 tensors implement the `fsdp_pre_all_gather` method, which splits a tensor into a uint8 data tensor plus metadata so that the unsharded FP8/MXFP8 tensor can be reconstructed after all-gathering the uint8 data. While building that metadata, we were creating a copy of the quantizer on every `fsdp_pre_all_gather` call, i.e. once per training iteration. This PR creates the copy only once, reducing CPU overhead.

    • Before change: [profiler screenshot]

    • After change: [profiler screenshot]

  • Using FusedAdam with FP8 and FSDP2 currently results in a segfault: the optimizer step handled only plain FP8 tensors and did not account for the DTensor-wrapped FP8 tensors that FSDP2 produces. This is now fixed, and FusedAdam works as expected with FP8 and FSDP2.
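The quantizer-copy fix in the first bullet boils down to caching the deep copy instead of recreating it each step. A minimal sketch of that pattern, using hypothetical stand-in classes (`Quantizer`, `Float8ShardedParam` and their fields are assumptions for illustration; only the `fsdp_pre_all_gather` hook name comes from this PR, and the real Transformer Engine types differ):

```python
import copy

class Quantizer:
    """Hypothetical stand-in for a quantizer holding scaling metadata."""
    copies = 0  # counts deep copies, to make the overhead visible

    def __init__(self, scale=1.0):
        self.scale = scale

    def __deepcopy__(self, memo):
        Quantizer.copies += 1
        return Quantizer(self.scale)

class Float8ShardedParam:
    """Illustrative sharded FP8 parameter; not Transformer Engine's API."""

    def __init__(self, uint8_data, quantizer):
        self._uint8_data = uint8_data        # raw uint8 payload to all-gather
        self._quantizer = quantizer
        self._quantizer_copy = None          # cached copy, created lazily once

    def fsdp_pre_all_gather(self):
        # Before the fix: a fresh quantizer copy was made on every call,
        # i.e. every training iteration. After the fix: copy once, reuse.
        if self._quantizer_copy is None:
            self._quantizer_copy = copy.deepcopy(self._quantizer)
        metadata = {"quantizer": self._quantizer_copy}
        return self._uint8_data, metadata
```

Calling `fsdp_pre_all_gather` repeatedly now triggers exactly one deep copy, which is the CPU-overhead reduction the before/after profiles illustrate.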
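The FusedAdam fix in the second bullet amounts to unwrapping the FSDP2 `DTensor` layer before reaching for the raw FP8 payload that the fused kernel expects; touching the wrapper as if it were a plain FP8 tensor is what led to the segfault. A toy sketch of the unwrapping order, with stand-in classes (`DTensorLike` mimics `DTensor.to_local()`; none of these names are Transformer Engine's actual types):

```python
class Float8Tensor:
    """Stand-in FP8 tensor holding a raw payload the fused kernel consumes."""
    def __init__(self, data):
        self._data = data

class DTensorLike:
    """Stand-in for a distributed tensor wrapping a local shard (FSDP2)."""
    def __init__(self, local):
        self._local = local

    def to_local(self):
        return self._local

def unwrap_for_fused_step(param):
    # The bug: only plain Float8Tensor was handled, but FSDP2 hands the
    # optimizer a DTensor whose *local shard* is the Float8Tensor.
    if isinstance(param, DTensorLike):
        param = param.to_local()   # peel off the distributed wrapper first
    if isinstance(param, Float8Tensor):
        return param._data         # then extract the raw payload
    return param                   # ordinary tensors pass through unchanged
```

With the DTensor case handled first, both wrapped and unwrapped FP8 parameters resolve to the same raw payload before the fused optimizer step.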

Checklist:

  • [ ] I have read and followed the contributing guidelines
  • [ ] The functionality is complete
  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [ ] I have made corresponding changes to the documentation
  • [ ] My changes generate no new warnings
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

vthumbe1503 · Nov 12 '25 04:11