
Fix excessive CPU memory usage with FSDP and cpu_ram_efficient_loading

Open · matthewdouglas opened this issue 1 year ago

What does this PR do?

This PR fixes an issue with FSDP + CPU_RAM_EFFICIENT_LOADING where every rank loads its own copy of the parameters into CPU memory. With this change, only rank 0 loads the weights onto the CPU; all other ranks load onto the meta device. On a typical 8-GPU node, this dramatically reduces the system RAM needed to load a large model, since the node holds one copy of the weights in CPU memory instead of eight.
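To make the saving concrete, here is a minimal sketch of the device choice and the resulting CPU RAM arithmetic. The helper names (load_device_for_rank, node_cpu_ram_gb) and the 70B-parameter bf16 model are hypothetical assumptions for illustration, not the actual transformers code path, and the estimate counts only the checkpoint weights.

```python
# Hypothetical sketch of the loading scheme described above; helper names and
# the example model size are illustrative assumptions, not transformers code.

BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2}


def load_device_for_rank(rank: int) -> str:
    """Device a rank loads checkpoint weights onto under the fix.

    Rank 0 materializes real tensors in CPU RAM; every other rank only
    creates shape/dtype placeholders on the meta device (no storage).
    """
    return "cpu" if rank == 0 else "meta"


def node_cpu_ram_gb(num_params, dtype, node_ranks, rank0_only):
    """Approximate CPU RAM (GB of 1e9 bytes) held by full weight copies on one node.

    Counts only the checkpoint weights, ignoring framework overhead,
    temporary buffers, and optimizer state.
    """
    per_copy_gb = num_params * BYTES_PER_PARAM[dtype] / 1e9
    if rank0_only:
        # After the fix: only ranks that load onto "cpu" hold a copy.
        copies = sum(1 for r in node_ranks if load_device_for_rank(r) == "cpu")
    else:
        # Before the fix: every rank on the node holds its own copy.
        copies = len(node_ranks)
    return per_copy_gb * copies


if __name__ == "__main__":
    # Assumed example: a 70B-parameter model in bf16 on one 8-GPU node.
    ranks = range(8)
    print(node_cpu_ram_gb(70e9, "bf16", ranks, rank0_only=False))  # → 1120.0
    print(node_cpu_ram_gb(70e9, "bf16", ranks, rank0_only=True))   # → 140.0
```

The meta-device ranks still need real weights eventually. In Accelerate's FSDP plugin, cpu_ram_efficient_loading is documented as requiring sync_module_states=True, so those ranks are expected to receive rank 0's weights by broadcast when FSDP wraps the model.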

This change is split out from PR #32276, originally contributed by @winglian, which was reverted because of problems we hit while validating it. Those problems have since been resolved.

Fixes #31721, #31577

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

@ArthurZucker @LysandreJik

matthewdouglas · Aug 27 '24 20:08


Ping me when this is ready for review!

ArthurZucker · Aug 28 '24 09:08

@ArthurZucker Ready!

matthewdouglas · Aug 28 '24 14:08

Thanks, @matthewdouglas!

winglian · Aug 28 '24 15:08

@ArthurZucker I've added more background to the description.

The issue we encountered was specific to our AWS cluster environment: with the AWS EFA plugin for NCCL, we hit consistent hangs. Upgrading NCCL from the version bundled with PyTorch (2.20.5) to NCCL 2.22.3 via pip install nvidia-nccl-cu12==2.22.3 resolves the issue. (Internal discussion)
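When reproducing this, it helps to confirm which NCCL version is actually in use before and after the upgrade. On recent PyTorch builds, torch.cuda.nccl.version() returns a tuple such as (2, 20, 5), and importlib.metadata.version("nvidia-nccl-cu12") reports the installed pip wheel as a string. The sketch below only compares such a value against 2.22.3; the helper names and the MIN_NCCL constant are hypothetical, and the 2.22.3 threshold reflects only the workaround described in this comment.

```python
# Hypothetical helper: compares a reported NCCL version against the 2.22.3
# workaround mentioned above. It does not query PyTorch or NCCL itself.

MIN_NCCL = (2, 22, 3)


def parse_nccl_version(version: str) -> tuple[int, ...]:
    """Turn a dotted version string such as "2.22.3" into (2, 22, 3)."""
    return tuple(int(part) for part in version.split("."))


def nccl_meets_minimum(installed, minimum=MIN_NCCL) -> bool:
    """True if `installed` (string or tuple) is at least `minimum`.

    Only the leading major/minor/patch fields are compared, so a trailing
    suffix field in a version tuple is ignored.
    """
    if isinstance(installed, str):
        installed = parse_nccl_version(installed)
    return tuple(installed)[: len(minimum)] >= tuple(minimum)


if __name__ == "__main__":
    print(nccl_meets_minimum((2, 20, 5)))  # → False (version bundled with PyTorch)
    print(nccl_meets_minimum("2.22.3"))    # → True (upgraded wheel)
```

Whether PyTorch picks up the upgraded wheel depends on how the installed build loads NCCL, so checking the version PyTorch reports at runtime is more reliable than checking only the pip package.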

matthewdouglas · Sep 04 '24 16:09

@ArthurZucker @matthewdouglas I tried this fix, but I'm running into NCCL issues similar to yours. Unfortunately, your suggestion to upgrade to the latest NCCL did not help. I understand you had some internal debugging discussions on this topic; could you share your NCCL environment settings and other package versions that might shed light on the root cause?

Update: I found the root cause, and it was not an NCCL issue. I have submitted a fix to TRL for it.

fabianlim · Sep 13 '24 01:09