
[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel

tianyu-l opened this issue (1 comment)

Stack from ghstack (oldest at bottom):

  • -> #437

Note: This PR is for showcasing purpose only and is almost a reverse of #190.

At the cost of changing model code, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (the sequence dim) instead of dim 0 (the folded dim), which incurs an extra aten.cat after each collective.
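To illustrate why the folded layout avoids the extra cat, here is a minimal sketch (using numpy rather than torch collectives, with hypothetical small sizes) of folding `(B, S, D)` into `(B*S, D)` and simulating a gather across 2 ranks on each layout:

```python
import numpy as np

B, S, D = 2, 8, 4  # hypothetical small sizes for illustration
x = np.arange(B * S * D, dtype=np.float32).reshape(B, S, D)

# Fold batch and sequence dims: (B, S, D) -> (B*S, D).
folded = x.reshape(B * S, D)

# Simulate sharding across 2 "ranks" on each layout.
shards_dim0 = np.split(folded, 2, axis=0)  # shards of the folded dim
shards_dim1 = np.split(x, 2, axis=1)       # shards of the sequence dim

# Gathering dim-0 shards: each shard is a contiguous block of rows, so
# the collective's output buffer is already in the final layout.
gathered_folded = np.concatenate(shards_dim0, axis=0)

# Gathering dim-1 shards: each rank's piece is interleaved across the
# batch dim, so producing a contiguous result requires an explicit
# concatenation along dim 1 (the extra aten.cat in the PyTorch trace).
gathered_seq = np.concatenate(shards_dim1, axis=1)

assert np.array_equal(gathered_folded.reshape(B, S, D), x)
assert np.array_equal(gathered_seq, x)
```

Both gathers reconstruct the same tensor; the difference is that only the dim-1 gather forces a re-layout copy after the collective.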

Stats from @awgu:

for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%)

Experiment on an 8-layer debug_model (before/after profiler screenshots omitted).

tianyu-l, Jul 04 '24