TransformerEngine

FSDP support

Open • yongyanrao opened this issue 1 year ago • 16 comments

I was wondering whether PyTorch's FullyShardedDataParallel (FSDP) is supported by TransformerEngine, and especially whether FP8 can work with FSDP. Thank you in advance.

yongyanrao, Aug 25 '23

In the current scheme, Transformer Engine modules use standard parameter tensors in standard dtypes (FP32/BF16/FP16), since optimizers typically require higher precision than FP8 to get good learning behavior. I don't see anything that would disrupt FSDP, and I've been able to get TE and FSDP working together in some quick experiments.
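As a rough illustration of why ordinary-dtype parameters are all FSDP needs, here is a toy, pure-Python model of the flatten, shard, and all-gather cycle that FSDP applies to a module's parameters. It is a simplified sketch, not PyTorch's implementation: `flatten`, `shard`, `all_gather`, and `unflatten` are hypothetical helpers, and plain Python lists stand in for tensors.

```python
# Toy model of FSDP-style parameter handling (hypothetical helpers, lists
# instead of tensors): flatten, pad, shard per rank, then all-gather.

def flatten(params):
    """Concatenate a dict of name -> list[float] into one flat list."""
    flat, layout = [], []
    for name, values in params.items():
        layout.append((name, len(flat), len(values)))  # (name, offset, numel)
        flat.extend(values)
    return flat, layout

def shard(flat, world_size):
    """Zero-pad so every rank gets an equal-length shard."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

def all_gather(shards, numel):
    """Concatenate every rank's shard and drop the padding."""
    full = [x for s in shards for x in s]
    return full[:numel]

def unflatten(flat, layout):
    """Split the flat list back into named parameters."""
    return {name: flat[off:off + n] for name, off, n in layout}

# Parameters held in an ordinary high-precision format: from FSDP's point
# of view there is nothing special about them.
params = {"weight": [0.1, -0.2, 0.3, 0.4, -0.5, 0.6], "bias": [0.01, 0.02, 0.03]}
flat, layout = flatten(params)
shards = shard(flat, world_size=2)
restored = unflatten(all_gather(shards, len(flat)), layout)

print([len(s) for s in shards])  # [5, 5]: 9 values padded to 10
print(restored == params)        # True
```

Nothing in this cycle depends on what the module computes with its weights, so a module that keeps standard FP32/BF16/FP16 parameters can be sharded like any other.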

timmoon10, Aug 25 '23

Hi Tim, thank you for the response. Have you tried FP8 + FSDP yet?

yongyanrao, Aug 29 '23

Hello @timmoon10, which FSDP did you use? Fairscale's?

MatthieuToulemont, Aug 30 '23

Hi, I was referring to PyTorch's FullyShardedDataParallel.

yongyanrao, Aug 30 '23

Thank you!

MatthieuToulemont, Aug 30 '23

Yep, I used PyTorch FSDP with TE FP8. Be advised I haven't done full convergence experiments, just some basic sanity checking.

timmoon10, Aug 30 '23

What is the recommendation for MixedPrecision when using FP8 with FSDP, @timmoon10?

import torch
from torch.distributed.fsdp import MixedPrecision

precision = None  # ? which dtype should go here with FP8 (None leaves params uncast)
is_amp_enabled = True  # placeholder for the training script's AMP flag

MixedPrecision(
    param_dtype=precision,  # should we be forcing dtype here?
    reduce_dtype=torch.float32,  # reduce in FP32 as with AMP?
    buffer_dtype=precision,  # buffers in FP32?
    cast_forward_inputs=is_amp_enabled,  # should we cast the forward samples?
)

jramapuram, Sep 01 '23

Transformer Engine manages FP8 casting internally (see transformer_engine.pytorch.fp8_autocast) and it can run into problems when combined with other mixed precision tools like torch.autocast or torch.distributed.fsdp.MixedPrecision. For the moment, FSDP mixed precision should only be used to handle the case where param_dtype and reduce_dtype do not match, e.g.:

import torch
from torch.distributed.fsdp import MixedPrecision
import transformer_engine.pytorch as te

precision = torch.bfloat16
model = te.Linear(32, 32, params_dtype=precision)

fsdp_mixed_precision = MixedPrecision(
    param_dtype=precision,
    reduce_dtype=torch.float32,
    buffer_dtype=precision,
)
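To make the `reduce_dtype=torch.float32` choice more concrete, the toy below (pure Python, not PyTorch) accumulates one large value and many small ones with the running total stored in bfloat16 versus float32. `to_bf16`, `to_f32`, and `reduce` are hypothetical helpers; real reductions run inside NCCL with a different summation order, so treat this only as an illustration of the rounding effect.

```python
import struct

def to_f32(x: float) -> float:
    """Round a Python float (a double) to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 (round half to even), finite inputs only.
    bfloat16 is the upper 16 bits of an IEEE-754 float32."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the 16 dropped bits
    bits &= 0xFFFF0000                   # keep only the bfloat16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def reduce(values, cast):
    """Sequential sum where the running total is stored in the given dtype."""
    total = cast(0.0)
    for v in values:
        total = cast(total + cast(v))
    return total

# One large contribution followed by many small ones.
contributions = [1.0] + [0.001] * 256
bf16_total = reduce(contributions, to_bf16)
fp32_total = reduce(contributions, to_f32)

print(bf16_total)  # 1.0: every small contribution is rounded away
print(fp32_total)  # close to the exact sum 1.256
```

In bfloat16 the spacing between representable values just above 1.0 is 2^-7 (about 0.0078), so adding 0.001 rounds straight back to 1.0.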

Edit: On second thought, I'm skeptical about this post. I don't remember my main concerns, and I've been able to get fp8_autocast working with torch.autocast and torch.distributed.fsdp.MixedPrecision. See https://github.com/NVIDIA/TransformerEngine/issues/438#issuecomment-1743693330.

timmoon10, Sep 01 '23

@timmoon10 I think the community would greatly benefit from an e2e example of how to train a 7B model on H100 node(s). This is the main purpose of this lib: to train large models quickly. Yet there is not a single example of how to do it; instead we have an MNIST example... Ideally, a simple pure PyTorch+TE file without any "framework" dependencies like Composer or Fabric.

PiotrDabkowski, Sep 02 '23

@PiotrDabkowski, please take a look at NeMo, which provides the exact scripts and tools to do this.

ksivaman, Sep 05 '23

@ksivaman: AFAIK NeMo uses tensor parallelism, which requires manual changes to the layers and architecture, whereas FSDP is fairly generic and easier to use.

jramapuram, Sep 16 '23

@jramapuram were you able to train FSDP with FP8?

naveenkumarmarri, Oct 19 '23

@yongyanrao #596 recently added support for deferred initialization via device='meta' to improve FSDP support for large models. This feature delays memory allocation on the device until the FSDP wrapper internally calls reset_parameters() after sharding the model weights, which ensures that TE modules are not duplicated on every device upon initialization.
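The idea behind deferred initialization can be sketched with a toy, pure-Python model: record only a parameter's size up front (the role `device='meta'` plays), then have each rank materialize just its own shard. `MetaParam`, `init_value`, `eager_init`, and `deferred_init_shard` are hypothetical names, FP32 storage is an assumption, and the sketch follows the description above rather than reproducing how FSDP actually materializes modules.

```python
from dataclasses import dataclass

BYTES_PER_ELEM = 4  # assumption: FP32 primary weights

@dataclass
class MetaParam:
    """Stands in for a parameter created on device='meta': size only, no storage."""
    numel: int

def init_value(i):
    """Hypothetical deterministic initializer that depends only on the element
    index, so any rank can initialize its own slice independently."""
    return ((i * 2654435761) % 1000) / 1000.0 - 0.5

def eager_init(p):
    """Every rank materializes the whole parameter (the duplication problem)."""
    return [init_value(i) for i in range(p.numel)]

def deferred_init_shard(p, rank, world_size):
    """Materialize only this rank's shard once sharding is decided, in the
    spirit of reset_parameters() running after the weights are sharded."""
    shard_len = -(-p.numel // world_size)  # ceiling division
    start = rank * shard_len
    stop = min(start + shard_len, p.numel)
    return [init_value(i) for i in range(start, stop)]

world_size = 8
param = MetaParam(numel=1000)
shards = [deferred_init_shard(param, r, world_size) for r in range(world_size)]

eager_bytes_per_rank = param.numel * BYTES_PER_ELEM
deferred_bytes_per_rank = max(len(s) for s in shards) * BYTES_PER_ELEM
print(eager_bytes_per_rank, deferred_bytes_per_rank)  # 4000 500
print(sum(shards, []) == eager_init(param))           # True: same values, sharded
```

The index-based initializer is what lets each rank fill in its shard independently while still producing the same weights as eager initialization.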

We also introduced a small FSDP example in the repo as part of this effort to demonstrate how to use FSDP with TE modules. It works out of the box with fp8_autocast(), with the requirement that fp8_autocast(...,fp8_group=...) gets the same process group as FSDP (both will default to the world group if none given).
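Our understanding is that `fp8_group` is the process group over which FP8 amax statistics are reduced when scaling factors are updated. The pure-Python toy below shows why reducing over the same data-parallel group as FSDP matters: with one shared group every rank derives the same scale, while per-rank groups leave each rank with a different one. `local_amax`, `reduce_amax`, and `scale_from_amax` are hypothetical helpers, and the scale formula is a simplified assumption modeled on delayed scaling, not TE's internals.

```python
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def local_amax(values):
    """Largest absolute value this rank has seen."""
    return max(abs(v) for v in values)

def reduce_amax(amax_by_rank, group):
    """Stand-in for an all-reduce with MAX over the ranks in `group`."""
    return max(amax_by_rank[r] for r in group)

def scale_from_amax(amax, margin=0):
    """Simplified scale that maps amax onto the top of the E4M3 range."""
    return (E4M3_MAX / amax) / (2 ** margin)

# Data-parallel ranks see different activations.
activations = {0: [0.5, -3.0], 1: [12.0, 1.0], 2: [-0.25, 2.0], 3: [7.0, -9.5]}
amaxes = {r: local_amax(v) for r, v in activations.items()}
world = list(activations)

# FP8 group matches the FSDP group: one reduction over all four ranks.
shared = {r: scale_from_amax(reduce_amax(amaxes, world)) for r in world}
# Mismatched setup: each rank only reduces over itself.
isolated = {r: scale_from_amax(reduce_amax(amaxes, [r])) for r in world}

print(len(set(shared.values())))    # 1: every rank uses the same scale
print(len(set(isolated.values())))  # 4: each rank quantizes differently
```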

denera, Jan 24 '24

@denera Thank you for the feature! This will be super helpful for us too. Just a question: are any additional steps needed to make primary_weights_in_fp8 work with FSDP, or should it work out of the box?

Thank you!

denizokt, Jan 30 '24

@denizokt fp8_model_init() is not supported with FSDP at the moment.

NCCL itself does not support 8-bit floats (see this discussion for more detail), so FSDP would need to upcast TE FP8 weights to FP16 for all-gathers and reduce-scatters, which it cannot do until PyTorch natively supports FP8 tensors.
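To see why such an upcast costs bandwidth but not accuracy, the sketch below simulates rounding to FP8 E4M3 (one of the two FP8 formats, with a largest finite value of 448) in plain Python, then round-trips the results through IEEE half precision using `struct`'s `'e'` format. `to_e4m3` and `to_fp16` are hypothetical helpers; the quantizer saturates and skips NaN/inf handling, so treat it as a simplified model rather than an exact reference.

```python
import math
import struct

E4M3_MAX = 448.0   # largest finite E4M3 value
E4M3_MIN_EXP = -6  # exponent of the smallest normal E4M3 value
MANTISSA_BITS = 3

def to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (saturating, finite inputs only).
    Simulated with Python floats; real FP8 storage is a single byte."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))       # abs(x) = m * 2**e with m in [0.5, 1)
    exp = max(e - 1, E4M3_MIN_EXP)  # clamp so tiny values use subnormal spacing
    step = 2.0 ** (exp - MANTISSA_BITS)
    q = round(x / step) * step      # round half to even
    return math.copysign(min(abs(q), E4M3_MAX), x)

def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

weights = [0.0137, -1.7, 3.14159, 250.0, -600.0, 0.0009]
fp8 = [to_e4m3(w) for w in weights]
upcast = [to_fp16(v) for v in fp8]

fp8_bytes = len(fp8) * 1      # 1 byte per E4M3 value
fp16_bytes = len(upcast) * 2  # 2 bytes per FP16 value
print(fp8)
print(upcast == fp8)          # True: every E4M3 value is exact in FP16
print(fp8_bytes, fp16_bytes)  # 6 12
```

Every E4M3 value fits exactly in FP16, so the upcast itself is lossless; the cost is doubling the bytes moved by each all-gather or reduce-scatter.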

The workaround for this limitation is what already happens in TE when primary_weights_in_fp8 = False: TE modules maintain their own FP8 weight copies, update them from the primary FP16/BF16 weights during the forward pass, and stash the transposed FP8 weights in the PyTorch autograd context to reuse them during the backward pass.
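The workaround can be modeled with a small toy layer in pure Python. `ToyFp8Linear` and `to_e4m3` are hypothetical; the per-call scaling from the current max |w| is a simplifying assumption (TE's delayed scaling uses an amax history instead), and lists stand in for tensors and for the autograd context.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def to_e4m3(x: float) -> float:
    """Simulated round-to-nearest FP8 E4M3 cast (saturating, finite inputs)."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(abs(x))
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; -6 = min normal exponent
    q = round(x / step) * step
    return math.copysign(min(abs(q), E4M3_MAX), x)

class ToyFp8Linear:
    """Hypothetical toy of the pattern described above (not TE code): the
    primary weight stays in high precision, where FSDP can shard it, and an
    FP8 copy plus its transpose are rebuilt from it on every forward pass."""

    def __init__(self, weight):
        self.weight = weight  # primary weight: rows of Python floats
        self.saved = None     # stands in for the autograd context

    def forward(self, x):
        amax = max(abs(w) for row in self.weight for w in row)
        scale = E4M3_MAX / amax
        w8 = [[to_e4m3(w * scale) for w in row] for row in self.weight]
        w8_t = [list(col) for col in zip(*w8)]  # transposed FP8 weight
        self.saved = (w8_t, scale)              # stash for the backward pass
        # y = W @ x, dequantizing by dividing out the scale
        return [sum(w * xi for w, xi in zip(row, x)) / scale for row in w8]

    def backward(self, grad_y):
        # grad_x = W^T @ grad_y, reusing the stashed FP8 transpose
        w8_t, scale = self.saved
        return [sum(w * g for w, g in zip(row, grad_y)) / scale for row in w8_t]

layer = ToyFp8Linear([[1.0, 2.0], [0.5, -1.0]])
y0 = layer.forward([1.0, 1.0])
grad_x = layer.backward([1.0, 0.0])
layer.weight[0][0] = 0.25       # e.g. an optimizer step on the primary weight
y1 = layer.forward([1.0, 1.0])  # the FP8 copy is refreshed from the new value
print(y0, grad_x, y1)  # [3.0, -0.5] [1.0, 2.0] [2.25, -0.5]
```

Because the FP8 copy is rebuilt from the primary weight on each forward pass, FSDP only ever has to shard, gather, and reduce the high-precision primary weights.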

denera, Jan 30 '24

@denera Thank you for the information, this makes a lot of sense.

denizokt, Jan 30 '24