[RFC][WIP][CP] Enable FlexAttention CP for llama3
Stack from ghstack (oldest at bottom):
- #1901
- #1897
- #1884
- -> #1883
- #1882
This PR uses the latest CP APIs to enable FlexAttention + CP for llama3. It removes the usage of the context_parallel() context manager and instead uses _context_parallel_shard() to shard the input data.
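For context, a minimal sketch of the idea behind the change: instead of wrapping the forward/backward pass in the context_parallel() context manager (which shards and restores the listed buffers along their sequence dims), the input batch is sharded up front across CP ranks. The helper name `shard_along_seq_dim` below is hypothetical and only a simplified illustration; the real _context_parallel_shard() API and its load-balancing behavior are not shown in this description.

```python
import torch

# Hypothetical stand-in illustrating sequence-dimension sharding of the input
# batch, the role played by _context_parallel_shard() in this PR. Each CP rank
# keeps only its slice of the sequence dimension, rather than relying on the
# context_parallel() context manager to shard/restore buffers around the step.
def shard_along_seq_dim(
    x: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1
) -> torch.Tensor:
    # Split the sequence dimension into cp_size equal chunks and keep this
    # rank's chunk (assumes seq_len is divisible by cp_size; the real API may
    # interleave chunks for load balancing).
    return x.chunk(cp_size, dim=seq_dim)[cp_rank]

if __name__ == "__main__":
    batch, seq_len, cp_size = 2, 8, 4
    tokens = torch.arange(batch * seq_len).reshape(batch, seq_len)
    for rank in range(cp_size):
        local = shard_along_seq_dim(tokens, rank, cp_size)
        print(f"rank {rank}: local shape {tuple(local.shape)}")  # (2, 2) each
```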
Pull-Request: https://github.com/pytorch/torchtitan/pull/1857