
[RFC] Lift freqs_cis as an input of models

Open XilunWu opened this issue 3 months ago • 0 comments

Stack from ghstack (oldest at bottom):

  • #1901
  • #1897
  • #1884
  • #1883
  • -> #1882

freqs_cis is sensitive to sequence order. CP (context parallel) load balancing shuffles the samples, so each batch ends up with a different order. As a result, we have to lift these order-sensitive buffers to model inputs and broadcast them along the batch dimension, so that PP (pipeline parallel) shards freqs_cis correctly.
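To illustrate the idea, here is a minimal, framework-free sketch in plain Python (nested lists stand in for tensors). It is not the code from this stack: `precompute_freqs_cis` follows the usual rotary-embedding formula, while `lift_freqs_cis`, `split_microbatches`, and the `positions` index lists are hypothetical names used only for illustration. It assumes the shuffle can be described by per-sample position indices and that PP splits inputs along the batch dimension.

```python
import cmath


def precompute_freqs_cis(head_dim, seq_len, theta=10000.0):
    """Shared rotary table: row p holds exp(i * p * theta^(-2k/head_dim))."""
    inv_freq = [theta ** (-2 * k / head_dim) for k in range(head_dim // 2)]
    return [[cmath.exp(1j * pos * f) for f in inv_freq] for pos in range(seq_len)]


def lift_freqs_cis(freqs_cis, positions_per_sample):
    """Turn the shared [seq][head_dim/2] table into a [batch][seq][head_dim/2] input.

    positions_per_sample[b][i] is the original position of token i in sample b
    after the (hypothetical) load-balancing shuffle.
    """
    return [[freqs_cis[p] for p in positions] for positions in positions_per_sample]


def split_microbatches(batched, num_chunks):
    """Split along the batch dimension (assumed here to be how PP shards inputs)."""
    size = len(batched) // num_chunks
    return [batched[i * size:(i + 1) * size] for i in range(num_chunks)]


table = precompute_freqs_cis(head_dim=4, seq_len=4)
positions = [[0, 3, 1, 2], [2, 1, 0, 3]]  # illustrative per-sample orders
lifted = lift_freqs_cis(table, positions)
chunks = split_microbatches(lifted, num_chunks=2)
```

Because the lifted value carries a batch dimension, each microbatch in `chunks` keeps the rows that match its own sample order. A shared buffer stored on the module has no batch dimension and could not follow the per-sample shuffle.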

Pull-Request-resolved: https://github.com/pytorch/torchtitan/pull/1797

XilunWu, Oct 15 '25 17:10