torchtitan

[RFC] Sharded embeddings in separate FSDP group

Open · awgu opened this issue

Stack from ghstack (oldest at bottom):

  • -> #182

If we shard the embeddings as a separate FSDP parameter group, then:

  • In forward, the root's parameters (final norm, output projection) are all-gathered first, followed by a separate all-gather for the embeddings. This makes the first all-gather smaller and allows the embedding's pre-forward casts to overlap with an all-gather.
  • In forward, the embedding parameters are resharded right after they are used near the beginning of forward, before any transformer block runs its forward.
  • In backward, the embedding parameters are all-gathered only near the end of backward, when memory usage is no longer near its peak.
  • In backward, the embedding's reduce-scatter and the root's reduce-scatter both remain exposed, since the embedding has the last gradient computation.
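The grouping described above can be sketched as a bottom-up wrapping order: one group per transformer block, one group for the embeddings, and the root last so that it owns only the leftover parameters. This is a minimal sketch, not torchtitan's code. The attribute names `layers` and `tok_embeddings`, the helper `apply_fsdp`, and its `shard_embeddings_separately` flag are hypothetical. `shard_fn` stands in for FSDP2's `fully_shard` so the ordering logic can be shown without a distributed setup.

```python
from typing import Any, Callable


def apply_fsdp(model: Any, shard_fn: Callable[..., Any],
               shard_embeddings_separately: bool = True) -> None:
    """Wrap a Llama-style model in FSDP parameter groups, bottom-up.

    Hypothetical assumptions: `model.layers` is a list- or dict-like
    container of transformer blocks, `model.tok_embeddings` is the token
    embedding, and the final norm / output projection stay on the root.
    """
    layers = model.layers
    blocks = layers.values() if hasattr(layers, "values") else layers
    # One FSDP group per transformer block.
    for block in blocks:
        shard_fn(block)
    if shard_embeddings_separately:
        # Separate group for the embeddings: with the default
        # reshard_after_forward=True it is resharded after its use at the
        # start of forward and all-gathered again near the end of backward.
        shard_fn(model.tok_embeddings)
    # Root group last: it owns every parameter not already in a group
    # (final norm, output projection, plus embeddings if not split out).
    shard_fn(model)


# Usage with a real model (assumes torch.distributed is initialized; recent
# PyTorch releases export fully_shard from torch.distributed.fsdp, older
# ones from torch.distributed._composable.fsdp):
#   from functools import partial
#   import torch
#   from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
#   mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
#   apply_fsdp(model, partial(fully_shard, mp_policy=mp))
```

Passing the sharding function in as a parameter keeps the grouping decision separate from the FSDP configuration, such as the mixed-precision policy.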

This reduces peak memory by roughly the size of the embedding parameters, with no first-order decrease in WPS. (It introduces 2 extra all-gathers and 1 extra reduce-scatter per step, which can hurt communication latency at large scale.)
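The "2 extra all-gathers and 1 reduce-scatter" count can be checked with a toy tally. This sketch assumes each FSDP group issues one all-gather in forward, a second all-gather in backward only if it was resharded after forward, and one reduce-scatter in backward. The `Group` type, the 32-block model, and the root being modeled as not resharding after forward are illustrative assumptions. They do not change the difference between the two configurations.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Group:
    name: str
    reshard_after_forward: bool = True


def collectives_per_step(groups):
    """Return (all_gathers, reduce_scatters) for one forward + backward."""
    all_gathers = 0
    reduce_scatters = 0
    for g in groups:
        all_gathers += 1          # all-gather before the group's forward
        if g.reshard_after_forward:
            all_gathers += 1      # re-all-gather before the group's backward
        reduce_scatters += 1      # gradient reduce-scatter after backward
    return all_gathers, reduce_scatters


blocks = [Group(f"layers.{i}") for i in range(32)]
root = Group("root", reshard_after_forward=False)  # modeling assumption
baseline = blocks + [root]
separate = blocks + [Group("tok_embeddings"), root]

ag0, rs0 = collectives_per_step(baseline)
ag1, rs1 = collectives_per_step(separate)
print(ag1 - ag0, rs1 - rs0)  # prints: 2 1
```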

For example, for Llama-7B with bf16 mixed precision, peak memory drops by ~0.84 GiB, and on 8 GPUs there is no noticeable effect on MFU.
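For scale, this is a rough estimate of the all-gathered bf16 embedding weight, which is one of the buffers that no longer needs to be alive at peak. It assumes Llama-2-7B's published configuration (`vocab_size=32000`, `dim=4096`). The helper names are hypothetical. The reported ~0.84 GiB is larger than this figure, so the measurement includes more than the unsharded weight alone. This sketch does not try to account for the difference.

```python
GIB = 1024 ** 3


def embedding_bytes(vocab_size: int, dim: int, bytes_per_elem: int) -> int:
    """Bytes for an unsharded (vocab_size x dim) embedding weight."""
    return vocab_size * dim * bytes_per_elem


def to_gib(num_bytes: int) -> float:
    return num_bytes / GIB


# Assumed Llama-2-7B shapes; bf16 uses 2 bytes per element.
unsharded_bf16 = embedding_bytes(vocab_size=32_000, dim=4_096, bytes_per_elem=2)
print(f"{to_gib(unsharded_bf16):.3f} GiB")  # prints: 0.244 GiB
```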

awgu, Apr 01 '24 17:04