torchtitan
[RFC][WIP][CP] Enable FlexAttention CP for llama3
Stack from ghstack (oldest at bottom):
- -> #1857
- #1939
This PR uses the latest CP APIs to enable FlexAttention + CP for llama3. It removes the usage of the `context_parallel()` context manager and instead uses `_context_parallel_shard()` to shard the input data.
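The shift from a wrapping context manager to explicit input sharding can be illustrated with a minimal, framework-free sketch. Note this is a hypothetical stand-in: `shard_for_cp` is an invented name, and the real `_context_parallel_shard()` operates on tensors and a CP process group rather than Python lists.

```python
def shard_for_cp(seq, cp_rank, cp_world_size):
    """Shard a sequence along its length for one context-parallel rank.

    Hypothetical stand-in for _context_parallel_shard(): each CP rank
    keeps only its contiguous slice of the sequence dimension, so the
    per-rank attention computation sees a shorter local sequence.
    """
    assert len(seq) % cp_world_size == 0, "sequence length must divide CP degree"
    chunk = len(seq) // cp_world_size
    return seq[cp_rank * chunk : (cp_rank + 1) * chunk]

# Example: a sequence of 8 token ids sharded across a CP degree of 4;
# rank 1 keeps tokens 2 and 3.
tokens = list(range(8))
local = shard_for_cp(tokens, cp_rank=1, cp_world_size=4)
print(local)  # [2, 3]
```

Sharding the inputs up front (rather than intercepting attention inside a context manager) fits FlexAttention, whose block-mask construction needs to know each rank's local sequence layout before the forward pass.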