torchtitan

Refactor freqs_cis slice to be safer for PP

Open wconstab opened this issue 1 year ago • 0 comments

Stack from ghstack (oldest at bottom):

  • #318
  • #322
  • -> #321

Unchanged: we precompute freqs_cis for max_seqlen, which is much larger than the seqlen of a given batch.
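To make the "precompute once for max_seqlen" step concrete, here is a minimal stdlib-only sketch of a Llama-style RoPE table. It is not torchtitan's implementation, which builds a complex torch tensor. The function name mirrors the freqs_cis naming, and `dim`, `theta=10000.0` and the sizes below are illustrative assumptions.

```python
import cmath


def precompute_freqs_cis(dim, max_seqlen, theta=10000.0):
    """Build a (max_seqlen, dim // 2) table of unit complex rotations.

    Row t, column i holds exp(1j * t * theta ** (-2i / dim)), the
    Llama-style RoPE rotation for position t and frequency pair i.
    """
    freqs = [theta ** (-(2 * i) / dim) for i in range(dim // 2)]
    return [[cmath.rect(1.0, t * f) for f in freqs] for t in range(max_seqlen)]


# Hypothetical sizes: the table is built once for max_seqlen ...
freqs_cis = precompute_freqs_cis(dim=8, max_seqlen=16)
# ... and a batch with a shorter seqlen only needs the first seqlen rows.
seqlen = 4
batch_freqs = freqs_cis[:seqlen]
print(len(freqs_cis), len(freqs_cis[0]), len(batch_freqs))  # 16 4 4
```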

Changed: instead of slicing self.freqs_cis down to seqlen in the top-level transformer based on the input token shape, we now slice it down to seqlen inside each transformer layer, after the activations have been re-expanded to the full seqlen in cases where TP has sharded them along the sequence dimension.
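The placement change can be sketched without any distributed machinery. In this stdlib-only toy, a list of per-rank shards stands in for a sequence-sharded activation. `all_gather_seq` and `layer_forward` are hypothetical stand-ins (not torchtitan functions) for the TP collective and a transformer layer. In a real run each rank holds only its own shard and the collective fetches the rest.

```python
def all_gather_seq(shards):
    """Hypothetical stand-in for the TP collective that re-expands a
    sequence-sharded activation: concatenate every rank's shard in order."""
    full = []
    for shard in shards:
        full.extend(shard)
    return full


def layer_forward(shards, freqs_cis):
    """Hypothetical transformer layer that receives the *unsliced* table."""
    x = all_gather_seq(shards)          # full seqlen, whatever the sharding
    local_freqs = freqs_cis[: len(x)]   # slice only once seqlen is unambiguous
    # Pair each position with its rotation row (stand-in for applying RoPE).
    return list(zip(x, local_freqs))


max_seqlen, seqlen, tp = 16, 8, 2
freqs_cis = [f"rot{t}" for t in range(max_seqlen)]   # placeholder table rows
tokens = [f"h{t}" for t in range(seqlen)]            # placeholder activations
chunk = seqlen // tp
shards = [tokens[r * chunk:(r + 1) * chunk] for r in range(tp)]

# Old placement: slice at the top level from rank 0's local input shape.
top_level_freqs = freqs_cis[: len(shards[0])]
# New placement: pass the full table down and slice inside the layer.
out = layer_forward(shards, freqs_cis)
print(len(top_level_freqs), len(out))   # 4 8
```

The top-level slice yields only seqlen/TP rows for a seqlen-long sequence, while the in-layer slice matches every position to its own rotation.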

In the PP case, the input to stage 1 may have length seqlen/TP instead of seqlen, but the stage generally cannot tell which from its input alone. That makes it hard for stage 1 to slice freqs_cis correctly. Slicing deeper inside the layer is easy, because at that point the full seqlen is known unambiguously.
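The ambiguity at stage 1 comes down to arithmetic: different configurations can deliver the same input length. The toy below uses hypothetical helper names and assumed sizes. It compares the row count a top-level slice would infer from stage 1's input with the row count a layer needs after re-expansion. The layer knows whether to re-expand because the TP plan applied to it says so.

```python
def stage1_local_len(seqlen, tp, seq_sharded):
    # Hypothetical: sequence length of the activation PP stage 1 receives.
    return seqlen // tp if seq_sharded else seqlen


def layer_full_len(local_len, tp, seq_sharded):
    # Inside the layer, a sharded activation is re-expanded (all-gather),
    # so the full seqlen is local_len * tp.
    return local_len * tp if seq_sharded else local_len


configs = [
    {"seqlen": 8, "tp": 1, "seq_sharded": False},   # no sequence sharding
    {"seqlen": 16, "tp": 2, "seq_sharded": True},   # TP-sharded along seq
]
results = []
for cfg in configs:
    local = stage1_local_len(**cfg)   # all a top-level slice can see
    full = layer_full_len(local, cfg["tp"], cfg["seq_sharded"])
    results.append((local, full))
print(results)  # [(8, 8), (8, 16)]
```

Both configurations hand stage 1 an input of length 8, but only the first actually needs 8 rows of freqs_cis.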

Note: the full self.freqs_cis is stored in memory either way, and what gets passed into each layer is just a view of it. This change should therefore have no material effect on memory usage or anything else.
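PyTorch basic slicing such as `self.freqs_cis[:seqlen]` returns a view that shares storage with the full tensor, so moving the slice does not change what is allocated. The stdlib analog below uses `memoryview` over an `array` to show the same property. The flat row-major layout and sizes are illustrative assumptions. Plain list slicing in Python would copy instead.

```python
from array import array

max_seqlen, half = 16, 4
# Flattened row-major table (max_seqlen x half) of placeholder angle values.
table = array("d", (float(t * half + i) for t in range(max_seqlen) for i in range(half)))

seqlen = 4
view = memoryview(table)[: seqlen * half]   # first seqlen rows, no copy

table[0] = -1.0    # mutate the full table ...
print(view[0])     # ... and the view sees it: -1.0
print(view.nbytes, table.itemsize * len(table))  # 128 512
```

In PyTorch the corresponding check is that `freqs_cis[:seqlen].data_ptr() == freqs_cis.data_ptr()`.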

wconstab, May 10 '24 23:05