Support combinations of precision plugins
Description & Motivation
Both the Fabric and Trainer strategies are designed to have a single plugin enabled from the beginning to the end of the program.
This has been fine historically; however, some strategies require tailored plugin implementations that are functionally equivalent to other plugins.
For instance, single-device training with bf16-true precision will use the HalfPrecision plugin, but FSDP training with bf16-true precision will use the FSDPPrecision plugin.
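For illustration, this is roughly how that selection plays out today from the user's perspective (a sketch; it assumes a CUDA machine, and the comments describe which plugin the connector picks internally):

```python
import lightning as L

# bf16-true on a single device: the connector picks the HalfPrecision plugin.
fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")

# The same precision string with FSDP: the connector picks FSDPPrecision instead,
# even though the dtype-conversion behavior largely overlaps with HalfPrecision.
fabric = L.Fabric(accelerator="cuda", devices=8, strategy="fsdp", precision="bf16-true")
```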
Recent cutting-edge plugins such as TransformerEnginePrecision and BitsandbytesPrecision also implement the basic bf16-true functionality, so there's a lot of overlap.
The challenge then becomes: how do I enable TransformerEnginePrecision to work with FSDPStrategy if FSDPStrategy is designed to work with FSDPPrecision only?
Note that I'm using these specific classes to prove the point, but the design issue applies to any strategy that requires a specific plugin class. DeepSpeedStrategy and XLAStrategy would also be examples of this.
Pitch
Continuing the example, there are three ways this could be solved:
- The naive way: Create a TransformerEngineFSDPPrecision. This is simple and effective, but it creates a maintainability problem.
- The independent way: If there are no dependencies between the plugins, we could support plugins=[TransformerEnginePrecision(), FSDPPrecision()]. But there will likely be dependencies.
- The smart way: Create an abstraction that can compose two (or more?) plugins together and is itself a plugin (see the sketch after this list). There's some precedent for this with the CheckpointIO plugins.
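To make the smart way concrete, here is a minimal sketch of what such a composing plugin could look like. This class does not exist in Lightning; the hook names come from Fabric's Precision interface, and the "apply plugins left to right" policy is only an assumption:

```python
from contextlib import ExitStack
from typing import Any

from lightning.fabric.plugins.precision import Precision


class CompositePrecision(Precision):
    """Hypothetical plugin that chains several precision plugins into one."""

    def __init__(self, *plugins: Precision) -> None:
        self.plugins = plugins

    def convert_module(self, module: Any) -> Any:
        # Let each plugin convert the module in turn (e.g. TE layer replacement,
        # then dtype conversion). Ordering matters and is left to the user.
        for plugin in self.plugins:
            module = plugin.convert_module(module)
        return module

    def forward_context(self) -> ExitStack:
        # Enter every plugin's forward context (autocast, fp8_autocast, ...).
        stack = ExitStack()
        for plugin in self.plugins:
            stack.enter_context(plugin.forward_context())
        return stack

    def convert_input(self, data: Any) -> Any:
        for plugin in self.plugins:
            data = plugin.convert_input(data)
        return data
```

The real work would be defining what happens when two composed plugins disagree, which is exactly the dependency problem mentioned above.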
Alternatives
No response
Additional context
No response
cc @borda @tchaton @justusschock @awaelchli @carmocca
@carmocca is there a way to use TransformerEnginePrecision with FSDP?
Not for now: https://github.com/NVIDIA/TransformerEngine/issues/401
@carmocca to understand it better, where is the bottleneck? Does TransformerEngine have to support it, or does FSDP?
TransformerEngine, and then we would need to integrate whatever is changed into Lightning.
cc @sbhavani in case you know about the progress for this
Transformer Engine + FSDP functionally works but doesn't provide memory savings. We are working on FP8 support (i.e. making it understand FP8 tensors) for PyTorch's FSDP implementation upstream, which would provide memory savings.
@sbhavani is there a timeline that you’re targeting for the feature to be available in FSDP?
@carmocca any plans to support FP8 training for the DeepSpeed strategy?
@sbhavani I see https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp exists now. Is your last comment still valid?
@carmocca Yes, upstream support for FP8 storage and comms in PyTorch's FSDP implementation is still in progress. However, with lazy weight initialization as in that example, you can use FP8 with FSDP.
I would like to suggest another use case where two different precision plugins may be needed. Suppose my LightningModule is composed of a transformer and a CNN. Is it possible to use bf16 for the transformer and 16-mixed for the CNN?
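For reference, Lightning's precision plugins don't handle per-submodule precision like this today, but it can be approximated by hand in plain PyTorch, e.g. keeping the transformer weights in bf16 while running the CNN under an fp16 autocast context. A rough sketch (module sizes are arbitrary):

```python
import torch
from torch import nn


class HybridModel(nn.Module):
    """Per-submodule precision by hand: bf16-true transformer, 16-mixed CNN."""

    def __init__(self) -> None:
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
        # "bf16-true": weights stored and computed in bfloat16
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2).to(torch.bfloat16)
        # "16-mixed": weights stay in float32, compute runs under fp16 autocast
        self.cnn = nn.Conv1d(64, 64, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.transformer(x.to(torch.bfloat16)).float()
        with torch.autocast(device_type=x.device.type, dtype=torch.float16, enabled=x.is_cuda):
            # (batch, seq, d_model) -> (batch, d_model, seq) for Conv1d
            return self.cnn(y.transpose(1, 2))
```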
@carmocca it might be worth adding FSDP+FP8 support via TE. HF Accelerate just added FSDP+FP8 support as a reference: https://github.com/huggingface/accelerate/blob/main/benchmarks/fp8/fsdp.py
@sbhavani if I understand correctly, as of today FP8 + FSDP support is not available with PyTorch Lightning?
Correct. It's very much on our radar though. Does anyone want to run ahead and submit a PR?
I do think that right now:
Create a TransformerEngineFSDPPrecision. This is simple and effective, but it creates a maintainability problem.
is going to have the highest ROI. Once we have this fully working we can think about generalizing in the future.
Here's the starting point if anyone would like to give it a shot: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/plugins/precision/transformer_engine.py
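For anyone picking this up, a very rough sketch of the naive-way class described in the pitch. It does not exist in Lightning; the weights_dtype attribute and the mixed_precision_config property that FSDPStrategy consumes are assumptions based on the current TransformerEnginePrecision and FSDPPrecision plugins:

```python
from torch.distributed.fsdp import MixedPrecision

from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision


class TransformerEngineFSDPPrecision(TransformerEnginePrecision):
    """Hypothetical plugin combining TE layer replacement with FSDP mixed precision."""

    @property
    def mixed_precision_config(self) -> MixedPrecision:
        # Keep parameters, buffers, and gradient reductions in the TE weights dtype.
        # FP8 storage/communication would still need the upstream FSDP support
        # discussed earlier in this thread.
        return MixedPrecision(
            param_dtype=self.weights_dtype,
            reduce_dtype=self.weights_dtype,
            buffer_dtype=self.weights_dtype,
        )
```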
@lantiga Any update on this? Is this also the reason why ModelParallelStrategy does not work with precision=transformer-engine?