
Support combinations of precision plugins

Open carmocca opened this issue 2 years ago • 12 comments

Description & Motivation

Both the Fabric and Trainer strategies are designed to have a single plugin enabled from the beginning to the end of the program.

This has been fine historically; however, some strategies require tailored plugin implementations that are functionally equivalent to other plugins.

For instance:

  - Single-device training with bf16-true precision will use the HalfPrecision plugin.
  - But FSDP training with bf16-true precision will use the FSDPPrecision plugin.
  - Recent cutting-edge plugins such as TransformerEnginePrecision and BitsandbytesPrecision also implement the basic bf16-true functionality (example).

So there's a lot of overlap.
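To illustrate, here is a minimal Fabric sketch of how the same user-facing precision flag resolves to different plugin classes today (this assumes a CUDA machine, with at least two GPUs for the FSDP line):

```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

# Same flag, different plugin class selected internally:
#   single device + "bf16-true" -> HalfPrecision
#   FSDP          + "bf16-true" -> FSDPPrecision
# even though both essentially cast module weights and inputs to bfloat16.
fabric_single = Fabric(accelerator="cuda", devices=1, precision="bf16-true")
fabric_fsdp = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(), precision="bf16-true")
```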

The challenge then becomes: how do I enable TransformerEnginePrecision to work with FSDPStrategy if FSDPStrategy is designed to work with FSDPPrecision only?

Note that I'm using these specific classes to prove the point, but the design issue applies to any strategy that requires a specific plugin class. DeepSpeedStrategy and XLAStrategy would also be examples of this.
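The conflict shows up as soon as you try to hand one strategy another plugin family's precision implementation. A hedged sketch of the combination this issue is asking for, using Fabric (whether it errors or is silently unsupported depends on the strategy's checks, and the TransformerEnginePrecision constructor arguments may differ across Lightning versions):

```python
import torch

from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import TransformerEnginePrecision
from lightning.fabric.strategies import FSDPStrategy

# The desired pairing: Transformer Engine's FP8 layer swapping together with
# FSDP sharding. Today FSDPStrategy is written against FSDPPrecision, so this
# combination is not supported end to end.
fabric = Fabric(
    accelerator="cuda",
    devices=2,
    strategy=FSDPStrategy(),
    plugins=[TransformerEnginePrecision(weights_dtype=torch.bfloat16)],
)
```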

Pitch

Continuing the example, there are three ways this could be solved:

  1. The naive way: Create a TransformerEngineFSDPPrecision. This is simple and effective, but it creates a maintainability problem.
  2. The independent way: If there are no dependencies between the plugins, we could support plugins=[TransformerEnginePrecision(), FSDPPrecision()]. But there will likely be dependencies.
  3. The smart way: Create an abstraction that is able to compose two (or more?) plugins together and is itself a plugin. There's some precedent for this with the CheckpointIO plugins. A rough sketch of this idea follows below.
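To make option 3 concrete, here is a rough, hypothetical sketch of such a composition wrapper. The ComposedPrecision name and the choice of hooks to forward are assumptions, not an existing API; the real Precision interface has more hooks, and the ordering/dependency semantics are exactly the open design question:

```python
from contextlib import ExitStack

from lightning.fabric.plugins.precision import Precision


class ComposedPrecision(Precision):
    """Hypothetical wrapper that chains several precision plugins into one plugin."""

    def __init__(self, *plugins: Precision) -> None:
        self.plugins = plugins

    def convert_module(self, module):
        # Apply each plugin's module conversion in order, e.g. TE layer swapping
        # first, then another plugin's dtype handling.
        for plugin in self.plugins:
            module = plugin.convert_module(module)
        return module

    def forward_context(self):
        # Nest every plugin's autocast/context manager around the forward pass.
        stack = ExitStack()
        for plugin in self.plugins:
            stack.enter_context(plugin.forward_context())
        return stack
```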

Alternatives

No response

Additional context

No response

cc @borda @tchaton @justusschock @awaelchli @carmocca

carmocca avatar Sep 29 '23 16:09 carmocca

@carmocca is there a way to use TransformerEnginePrecision with FSDP?

naveenkumarmarri avatar Oct 19 '23 03:10 naveenkumarmarri

Not for now: https://github.com/NVIDIA/TransformerEngine/issues/401

carmocca avatar Oct 19 '23 20:10 carmocca

@carmocca to understand it better, where is the bottleneck? Does Transformer Engine have to support it, or does FSDP have to support this?

naveenkumarmarri avatar Oct 19 '23 21:10 naveenkumarmarri

TransformerEngine, and then we would need to integrate whatever is changed into Lightning.

cc @sbhavani in case you know about the progress for this

carmocca avatar Oct 19 '23 21:10 carmocca

Transformer Engine + FSDP functionally works but doesn't provide memory savings. We are working on FP8 support in PyTorch's FSDP implementation upstream (i.e., making FSDP understand FP8 tensors), which would provide memory savings.
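For context, this is what standalone Transformer Engine FP8 usage looks like outside Lightning (a minimal sketch following TE's quickstart API, assuming an FP8-capable GPU such as Hopper). Wrapping such modules in FSDP runs, but the parameters are still stored and communicated in higher precision, which is why there are no memory savings yet:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Compute inside the fp8_autocast region runs in FP8, but the layer's weights
# remain ordinary higher-precision tensors, so sharding them with FSDP today
# does not reduce memory.
layer = te.Linear(1024, 1024).cuda()
recipe = DelayedScaling(fp8_format=Format.HYBRID)

x = torch.randn(16, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
y.sum().backward()
```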

sbhavani avatar Oct 23 '23 18:10 sbhavani

@sbhavani is there a timeline that you’re targeting for the feature to be available in FSDP?

naveenkumarmarri avatar Oct 24 '23 02:10 naveenkumarmarri

@carmocca any plans to support FP8 training with the DeepSpeed strategy?

naveenkumarmarri avatar Oct 26 '23 21:10 naveenkumarmarri

@sbhavani I see https://github.com/NVIDIA/TransformerEngine/tree/main/examples/pytorch/fsdp exists now. Is your last comment still valid?

carmocca avatar Feb 29 '24 16:02 carmocca

@carmocca Yes, upstream support for FP8 storage and comms in PyTorch's FSDP implementation is still in progress. However, with the lazy weight initialization in that example, you can use FP8 with FSDP.

sbhavani avatar Mar 12 '24 17:03 sbhavani

I would like to suggest another use case where two different precision plugins may be needed. Suppose my LightningModule is composed of a transformer and a CNN. Is it possible to use bf16 for the transformer and 16-mixed for the CNN?
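For reference, here is a plain-PyTorch sketch of that split, assuming a hypothetical model with transformer and cnn submodules and a CUDA device; the open question is how to express something like this through Lightning's precision plugins rather than by hand:

```python
import torch
import torch.nn as nn


class HybridModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.cnn = nn.Conv1d(64, 64, kernel_size=3, padding=1)


model = HybridModel().cuda()
model.transformer.to(torch.bfloat16)   # "bf16-true": cast the transformer weights in place
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()   # "16-mixed" needs loss scaling

x = torch.randn(8, 16, 64, device="cuda")
hidden = model.transformer(x.to(torch.bfloat16))               # transformer runs in bf16
with torch.autocast(device_type="cuda", dtype=torch.float16):  # CNN runs under fp16 autocast
    out = model.cnn(hidden.float().transpose(1, 2))
loss = out.float().mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```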

function2-llx avatar Aug 14 '24 11:08 function2-llx

@carmocca it might be worth adding FSDP+FP8 support via TE. HF Accelerate just added FSDP+FP8 support as a reference: https://github.com/huggingface/accelerate/blob/main/benchmarks/fp8/fsdp.py

sbhavani avatar Aug 15 '24 16:08 sbhavani

@sbhavani if I understand correctly, FP8 + FSDP support is not available in PyTorch Lightning as of today?

psr-ai avatar Aug 17 '24 12:08 psr-ai

Correct. It's very much on our radar though. Does anyone want to run ahead and submit a PR?

I do think that right now:

Create a TransformerEngineFSDPPrecision. This is simple and effective, but it creates a maintainability problem.

is going to have the highest ROI. Once we have this fully working we can think about generalizing in the future.

Here's the starting point if anyone would like to give it a shot: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/fabric/plugins/precision/transformer_engine.py
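To make the "naive way" more concrete, here is a bare-bones, hypothetical sketch of such a subclass. It only illustrates the direction (reuse TE's layer swapping and fp8_autocast, and expose the MixedPrecision config that FSDP consumes); FSDPStrategy's type checks and the rest of the FSDPPrecision surface (gradient clipping, state-dict handling, etc.) are deliberately left open and would need to be worked out in the PR:

```python
import torch
from torch.distributed.fsdp import MixedPrecision

from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision


class TransformerEngineFSDPPrecision(TransformerEnginePrecision):
    """Hypothetical: TE's convert_module/forward_context plus an FSDP dtype config."""

    @property
    def mixed_precision_config(self) -> MixedPrecision:
        # FSDPPrecision exposes a property like this for FSDPStrategy to consume;
        # here everything FSDP manages stays in bf16 while TE handles FP8 compute.
        return MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
```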

lantiga avatar Sep 16 '24 11:09 lantiga

@lantiga Any update on this? Is this also the reason why ModelParallelStrategy does not work with precision=transformer-engine?

radulescupetru avatar Sep 11 '25 07:09 radulescupetru