[PyTorch] Distributed intermediate/activation tensors for FSDP
torch.distributed.fsdp.FullyShardedDataParallel cannot scatter/gather the intermediate/activation tensors that TE modules pack into the autograd context at the end of their forward passes. As a result, full-size (unsharded) activation and FP8 weight tensors remain in memory between the forward and backward passes.
This PR provides a te.distributed.prepare_te_modules_for_fsdp(fsdp_root) API that inserts references to the correct FSDP process group into the FSDP-wrapped TE modules of a given model. The TE modules then use these process groups to scatter the intermediate/activation tensors at the end of the forward pass, before packing them into the autograd context, and to gather the same tensors back at the beginning of the backward pass, before compute.
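
A rough usage sketch (assuming `te` aliases `transformer_engine.pytorch`, the script is launched with `torchrun`, and the module choice, sizes, and FSDP options are purely illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te

# Launched via torchrun, which sets the rendezvous env vars for init_process_group().
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Any TE module works; a single te.Linear keeps the sketch small.
model = te.Linear(4096, 4096, bias=True).cuda()

# Shard parameters with FSDP as usual.
fsdp_model = FSDP(model, device_id=torch.cuda.current_device())

# New in this PR: insert FSDP process-group references into the wrapped TE modules
# so their saved intermediate/activation tensors are scattered after the forward
# pass and gathered again before the backward pass.
te.distributed.prepare_te_modules_for_fsdp(fsdp_model)

inp = torch.randn(32, 4096, device="cuda")
# FP8 execution requires supported hardware (e.g. Hopper); drop fp8_autocast otherwise.
with te.fp8_autocast(enabled=True):
    out = fsdp_model(inp)
out.sum().backward()
```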
Using te.distributed.checkpoint() turns off these scatters/gathers to avoid unnecessary communication for tensors that need to be recomputed during the backward pass anyway.
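
Continuing the sketch above, a hedged illustration of the checkpointed path; the exact signature of te.distributed.checkpoint() is assumed here and may differ between TE versions:

```python
# Recompute the forward during backward instead of saving activations. In this case
# the FSDP scatter/gather of saved tensors is skipped, since those tensors are
# rebuilt during recompute anyway. (checkpoint() argument layout is an assumption.)
with te.fp8_autocast(enabled=True):
    out = te.distributed.checkpoint(fsdp_model, inp)
out.sum().backward()
```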