gpt-neox
Add support for context parallelism
Adds context parallelism with ring attention.
See WandB report for training runs to test correctness with a simple 410M config: https://wandb.ai/brandony/neox/reports/Test-context-parallelism-correctness--Vmlldzo5NTU4ODc1.
- Checked the following settings: no MP/CP, MP 4, CP 4, and MP 2 + CP 2
- Confirmed that the loss matches exactly across these configurations
- Memory usage and training speed both seem reasonable
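For reference, the core idea behind ring attention is that each rank keeps its own query chunk while K/V chunks circulate around the ring, with partial results combined via an online softmax. Below is a minimal single-process NumPy sketch of that blockwise accumulation (illustrative only — no actual communication, and not the gpt-neox implementation; function names are made up for this sketch):

```python
import numpy as np

def full_attention(q, k, v):
    # Reference: standard softmax attention over the full sequence.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

def ring_attention(q_chunks, k_chunks, v_chunks):
    # Each "rank" r holds q_chunks[r] and, at ring step `step`, sees the
    # K/V chunk originally owned by rank (r + step) % n. Partial attention
    # outputs are merged with a streaming (online) softmax, so the final
    # result equals full attention without any rank materializing the
    # whole score matrix.
    n = len(q_chunks)
    outs = []
    for r in range(n):
        q = q_chunks[r]
        d = q.shape[-1]
        m = np.full(q.shape[0], -np.inf)                     # running row max
        l = np.zeros(q.shape[0])                             # running denominator
        acc = np.zeros((q.shape[0], v_chunks[0].shape[-1]))  # running numerator
        for step in range(n):
            src = (r + step) % n
            s = q @ k_chunks[src].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=-1))
            scale = np.exp(m - m_new)                        # rescale old partials
            p = np.exp(s - m_new[:, None])
            acc = acc * scale[:, None] + p @ v_chunks[src]
            l = l * scale + p.sum(axis=-1)
            m = m_new
        outs.append(acc / l[:, None])
    return np.concatenate(outs, axis=0)
```

In the distributed setting, each loop iteration over `src` corresponds to a send/recv of K/V blocks between neighboring CP ranks, overlapped with the local attention compute.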
Based on this initial PR: https://github.com/EleutherAI/gpt-neox/pull/1266, with the following changes to get things working:
- All-reduce gradients across context-parallel ranks by piggybacking on the DP all-reduce
- Fix parallelism initialization
- Remove unnecessary code
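To illustrate the gradient-piggybacking point: since every CP rank holds a full copy of each model shard's parameters, CP replicas can simply be folded into the data-parallel gradient all-reduce group. A small sketch of such a rank-to-group mapping (the `build_groups` helper and the `[dp, cp, mp]` rank layout with MP fastest-varying are assumptions for illustration, not the exact gpt-neox topology code):

```python
def build_groups(world_size, mp, cp):
    """Sketch: construct MP, CP, and gradient all-reduce rank groups.

    Assumed rank layout (illustrative): [dp, cp, mp], MP index fastest.
    """
    assert world_size % (mp * cp) == 0
    dp = world_size // (mp * cp)
    # Model-parallel groups: consecutive ranks shard one model replica.
    mp_groups = [list(range(i * mp, (i + 1) * mp))
                 for i in range(world_size // mp)]
    # Context-parallel groups: ranks holding different sequence chunks of
    # the same model shard within one data-parallel replica.
    cp_groups = [[d * cp * mp + c * mp + m for c in range(cp)]
                 for d in range(dp) for m in range(mp)]
    # Gradient all-reduce groups: fold CP into DP, so each shard's
    # gradients are averaged over all dp * cp parameter copies
    # ("piggybacking" the CP reduction on the DP all-reduce).
    grad_groups = [[d * cp * mp + c * mp + m
                    for d in range(dp) for c in range(cp)]
                   for m in range(mp)]
    return mp_groups, cp_groups, grad_groups
```

With `world_size=16, mp=2, cp=2`, each gradient group spans the `dp * cp = 8` ranks that replicate a given model shard, so a single all-reduce per shard covers both DP and CP.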