torchtitan
[RFC] Lift freqs_cis as an input of models
Stack from ghstack (oldest at bottom):
- #1901
- #1897
- #1884
- #1883
- -> #1882
freqs_cis is sensitive to sequence order. Context parallel (CP) load balancing reorders the tokens of each sample, so each batch can have a different order. As a result, we have to lift these order-sensitive buffers to model inputs and broadcast them along the batch dimension. That way pipeline parallelism (PP) shards freqs_cis together with the other inputs, and each microbatch keeps the rotary frequencies that match its tokens.
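The sketch below illustrates why lifting and broadcasting keeps things correct. It makes several assumptions. `precompute_freqs_cis` and `apply_rotary_emb` are Llama-style RoPE helpers written for illustration, not torchtitan's actual code. `lift_freqs_cis` is a hypothetical helper. A fixed random permutation stands in for CP load balancing, and `torch.chunk` along dim 0 stands in for PP microbatch splitting. The check confirms that rotating the reordered tokens in microbatches, using the lifted per-sample freqs_cis, gives the same result as rotating in the original order and then reordering.

```python
import torch


def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Llama-style RoPE table (illustrative): one complex rotation per (position, frequency).
    # Shape: (seq_len, head_dim // 2), dtype complex64.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, n_heads, head_dim)
    # freqs_cis: (batch, seq_len, head_dim // 2), already lifted to a per-sample input.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(x_c * freqs_cis.unsqueeze(2)).flatten(3)
    return out.type_as(x)


def lift_freqs_cis(freqs_cis: torch.Tensor, order: torch.Tensor, batch: int) -> torch.Tensor:
    # Hypothetical helper: reorder the table to match the load-balanced token order,
    # then broadcast it along the batch dimension so it can be split like any other input.
    return freqs_cis[order].unsqueeze(0).expand(batch, -1, -1)


torch.manual_seed(0)
batch, seq_len, n_heads, head_dim, n_microbatches = 4, 8, 2, 16, 2
x = torch.randn(batch, seq_len, n_heads, head_dim)
table = precompute_freqs_cis(head_dim, seq_len)

# Stand-in for CP load balancing: a fixed permutation of sequence positions.
order = torch.randperm(seq_len)
x_shuffled = x[:, order]
freqs_in = lift_freqs_cis(table, order, batch)

# Stand-in for PP microbatching: split every input along the batch dimension.
x_chunks = x_shuffled.chunk(n_microbatches, dim=0)
freqs_chunks = freqs_in.chunk(n_microbatches, dim=0)
out_shuffled = torch.cat(
    [apply_rotary_emb(xb, fb) for xb, fb in zip(x_chunks, freqs_chunks)], dim=0
)

# Reference: rotate in the original order, then apply the same reordering.
reference = apply_rotary_emb(x, table.unsqueeze(0).expand(batch, -1, -1))[:, order]
print(torch.allclose(out_shuffled, reference, atol=1e-6))  # expected: True
```

Because freqs_cis now has a leading batch dimension, the same split that PP applies to the tokens also applies to the frequencies, so no microbatch ends up with rotations for positions it does not hold.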
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/1797