torchtitan
[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel
Stack from ghstack (oldest at bottom):
- -> #437
Note: This PR is for showcasing purposes only and is roughly a reverse of #190.
At the cost of a model code change, we can obtain better Sequence Parallel performance: by folding the batch and sequence dimensions together, all-gather and reduce-scatter run on dim 0 (the folded dim) instead of dim 1 (the sequence dim). Without folding and unfolding, the collectives operate on dim 1, which incurs an extra aten.cat after each collective to stitch the gathered shards back together.
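A minimal sketch of the fold/unfold idea (shapes and names are illustrative, not torchtitan's actual code): a shard gathered along dim 0 of a folded tensor is already contiguous, whereas gathering along dim 1 leaves non-contiguous pieces that must be concatenated.

```python
import torch

# Hypothetical shapes for illustration.
bsz, seq, dim = 2, 8, 4
x = torch.randn(bsz, seq, dim)

# Fold: [bsz, seq, dim] -> [bsz * seq, dim]. A collective on dim 0 of the
# folded tensor writes each rank's shard into one contiguous slab, so the
# result needs no extra concatenation.
folded = x.view(bsz * seq, dim)

# ... Sequence Parallel region would operate on `folded` here ...

# Unfold: restore [bsz, seq, dim] after the collective region.
unfolded = folded.view(bsz, seq, dim)
assert torch.equal(unfolded, x)

# Contrast: simulating a 2-way gather on dim 1 (sequence dim) requires a
# cat to reassemble the shards, which is the extra aten.cat this PR avoids.
shards = x.chunk(2, dim=1)
reassembled = torch.cat(shards, dim=1)
assert torch.equal(reassembled, x)
```

Since `view` is a zero-copy reshape on a contiguous tensor, the fold/unfold itself is essentially free; the savings come from the collectives no longer needing a trailing cat.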
Stats from @awgu:
For an 8k sequence length and batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (~6%).
Experiment on 8-layer debug_model
before:
after: