torchtitan
Refactor freqs_cis slice to be safer for PP
Stack from ghstack (oldest at bottom):
- #318
- #322
- -> #321
Unchanged: we precompute freqs_cis for max_seqlen, which is much larger than the seqlen of any given batch.
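As a rough illustration of that precomputation, here is a framework-free sketch in plain Python. It assumes the Llama-style RoPE formula with theta=10000; the helper name and the dim/max_seqlen values are illustrative, not torchtitan's actual code.

```python
import cmath


def precompute_freqs_cis(dim: int, max_seqlen: int, theta: float = 10000.0):
    """Return a max_seqlen x (dim // 2) table of unit complex rotations.

    Row t holds exp(i * t * theta^(-2k/dim)) for k = 0 .. dim//2 - 1.
    Assumption: Llama-style RoPE frequencies; the name is hypothetical here.
    """
    inv_freqs = [1.0 / (theta ** (2 * k / dim)) for k in range(dim // 2)]
    # cmath.rect(1.0, angle) is the unit complex number at that angle.
    return [[cmath.rect(1.0, t * f) for f in inv_freqs] for t in range(max_seqlen)]


# Computed once for the longest sequence the model supports...
freqs_cis = precompute_freqs_cis(dim=8, max_seqlen=32)

# ...while a given batch only needs the first seqlen rows.
seqlen = 10
batch_freqs = freqs_cis[:seqlen]
```

The table depends only on the model config, so it is built once up front and each forward pass only needs a prefix of it.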
Changed: instead of slicing self.freqs_cis down to seqlen in the top-level transformer based on the shape of the input tokens, we now slice it down to seqlen inside each transformer layer, after the activations have been re-expanded to the full seqlen in cases where TP has sharded them along the sequence dimension.
In the PP case, stage 1's input may have length seqlen/TP rather than seqlen, and in general the stage cannot tell which. That makes it hard for stage 1 to slice freqs_cis correctly. Slicing deeper inside is easy, since at that point the full seqlen is known unambiguously.
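The difference between the two slicing points can be simulated with a toy model in plain Python. Everything here is hypothetical: the table rows are just position indices standing in for the complex rotations, and `shard_sequence` / `all_gather` are stand-ins for sequence-parallel sharding and the TP all-gather, not torchtitan or torch.distributed APIs.

```python
# Toy model (assumption): row t of the table is just the position index t,
# standing in for the complex rotations of position t.
MAX_SEQLEN = 16
freqs_cis = list(range(MAX_SEQLEN))

SEQLEN = 8  # full sequence length of this batch
TP = 2      # tensor-parallel degree; sequence parallel shards the seq dim


def shard_sequence(tokens, tp):
    """Hypothetical stand-in: split activations along the sequence dim."""
    chunk = len(tokens) // tp
    return [tokens[r * chunk:(r + 1) * chunk] for r in range(tp)]


def all_gather(shards):
    """Hypothetical stand-in: re-expand the sequence dim to full length."""
    return [tok for shard in shards for tok in shard]


def slice_at_top_level(freqs_cis, stage_input):
    # Old approach: infer seqlen from whatever this PP stage received.
    return freqs_cis[: len(stage_input)]


def slice_inside_layer(freqs_cis, full_activations):
    # New approach: slice only after the sequence dim is back to full length.
    return freqs_cis[: len(full_activations)]


activations = list(range(SEQLEN))
shards = shard_sequence(activations, TP)
stage_input = shards[0]  # what a non-first PP stage sees on TP rank 0

old = slice_at_top_level(freqs_cis, stage_input)
new = slice_inside_layer(freqs_cis, all_gather(shards))
print(len(old), len(new))  # 4 8
```

Slicing by the stage's input length yields only seqlen/TP rows, while slicing after the re-expansion yields the correct seqlen rows.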
Note: the full self.freqs_cis is kept in memory either way, and what gets passed into each layer is just a view. This change should have no material effect on memory usage or anything else.