
[PyTorch] Distributed intermediate/activation tensors for FSDP

Open • denera opened this issue on Feb 28, 2024 • 0 comments

torch.distributed.fsdp.FullyShardedDataParallel cannot scatter/gather the intermediate/activation tensors that TE modules pack into the autograd context at the end of their forward pass, so these activations and the Fp8 weight tensors stay in memory at their full, unsharded (global) size.

This PR provides a te.distributed.prepare_te_modules_for_fsdp(fsdp_root) API that inserts references to the correct FSDP process group into the FSDP-wrapped TE modules of a given model. The TE modules then use these process groups to scatter the intermediate/activation tensors at the end of the forward pass, before packing them into the autograd context. The same tensors are gathered at the beginning of the backward pass, before compute.
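
A minimal usage sketch of the new API, assuming the function name from this PR and a typical torchrun-launched FSDP setup; the model dimensions and FSDP wrapping options below are illustrative, not part of the PR:

```python
import torch
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Typical single-node multi-GPU setup (launched via torchrun).
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# Three TE LayerNormMLP blocks, mirroring the benchmark below.
model = torch.nn.Sequential(
    te.LayerNormMLP(1024, 4096),
    te.LayerNormMLP(1024, 4096),
    te.LayerNormMLP(1024, 4096),
).cuda()

# FSDP shards the parameters, but not the intermediate/activation tensors
# that the TE modules save for backward.
fsdp_root = FSDP(model, use_orig_params=True)

# Insert references to the FSDP process groups into the wrapped TE modules so
# that their saved intermediates are scattered after the forward pass and
# gathered again before the backward pass.
te.distributed.prepare_te_modules_for_fsdp(fsdp_root)
```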

Using te.distributed.checkpoint() turns off these scatters/gathers to avoid unnecessary communication for tensors that need to be recomputed anyway.
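
For reference, a hedged sketch of combining this with TE activation checkpointing, continuing from the code above; it assumes te.distributed.checkpoint accepts a torch.utils.checkpoint-style call of the form checkpoint(function, *args), so check the actual signature in your TE version:

```python
# Checkpointed regions recompute their intermediates during backward, so the
# extra scatter/gather of saved tensors is skipped for them.
x = torch.randn(32, 1024, device="cuda", requires_grad=True)
out = te.distributed.checkpoint(fsdp_root, x)  # assumed torch-style signature
out.sum().backward()
```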

nn.Sequential( 3 x te.LayerNormMLP ) before Fp8/intermediate sharding:

(screenshot: no_fp8_sharding)

nn.Sequential( 3 x te.LayerNormMLP ) after Fp8/intermediate sharding:

(screenshot: with_fp8_sharding)
