Sequence Parallel
Motivation
When serving an extremely large model (e.g. Llama 400B), the number of GPUs can exceed the number of KV heads. Plain tensor parallelism then has to replicate the KV cache across GPUs (e.g. with 8 KV heads on 16 GPUs, every KV head is stored twice), which is troublesome when the sequence length is very large.
Modification
This PR introduces a very basic sequence parallelism for the attention computation. All other parts of the model remain fully tensor parallelized, so the partitioning scheme switches before and after the attention. This is achieved by the following changes (a toy sketch of the layout switch is given after the list):
- When preparing the batch, collocate the input ids that belong to the same sequence parallel rank (`sp_rank`); this is referred to as the sequence parallel layout in this PR and in the code comments. The FlashInfer arguments are changed accordingly.
- Before entering the SP part, only the locally stored KV is computed. (python/sglang/srt/layers/linear.py)
- The SP attention kernel, which still has room for improvement. (python/sglang/srt/layers/radix_attention.py)
- When leaving the SP part, the whole sequence is gathered again, because the rest of the model operates on the full sequence.
- The output logits are switched back to the original layout before sampling.
- Miscellaneous modifications, including: the parallel state (python/sglang/srt/layers/parallel_utils/parallel_state.py), wiring the components together in the model runner (python/sglang/srt/managers/controller/model_runner.py), the model definition (python/sglang/srt/models/llama2.py), the server args (python/sglang/srt/server_args.py), and the tests.
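
For readers unfamiliar with the layout switch, here is a minimal, single-process sketch of the idea. It is not the code in this PR: the names (`sp_shard_indices`, `sp_size`, etc.) are made up for illustration, the all-gather is simulated with `torch.cat`, and a real kernel would overlap communication and use an optimized attention implementation.

```python
# Toy, single-process sketch of sequence-parallel attention layout switching.
# All names here are illustrative only; they do not mirror the sglang code.
import torch

sp_size = 2                       # number of sequence-parallel ranks
hidden = 16
seq_lens = [5, 3]                 # two packed sequences, 8 tokens in total


def sp_shard_indices(seq_lens, sp_size):
    """Token indices (into the packed batch) owned by each sp_rank.

    Every sequence is cut into sp_size contiguous chunks and rank r owns
    chunk r of every sequence; collocating one rank's tokens gives the
    "sequence parallel layout"."""
    shards = [[] for _ in range(sp_size)]
    offset = 0
    for length in seq_lens:
        bounds = [offset + (length * r) // sp_size for r in range(sp_size + 1)]
        for r in range(sp_size):
            shards[r].extend(range(bounds[r], bounds[r + 1]))
        offset += length
    return [torch.tensor(s) for s in shards]


total = sum(seq_lens)
x = torch.randn(total, hidden)            # packed hidden states, original layout
shards = sp_shard_indices(seq_lens, sp_size)

# Per-token sequence id / position, used for a block-diagonal causal mask.
seq_id = torch.cat([torch.full((n,), i) for i, n in enumerate(seq_lens)])
pos = torch.cat([torch.arange(n) for n in seq_lens])

# Before the SP region: every rank projects QKV only for its local tokens.
wq, wk, wv = (torch.randn(hidden, hidden) for _ in range(3))
local_q = [x[idx] @ wq for idx in shards]
local_k = [x[idx] @ wk for idx in shards]
local_v = [x[idx] @ wv for idx in shards]

# SP attention: each rank's queries attend to KV gathered from all ranks.
# torch.cat + argsort stands in for the all-gather / re-ordering step.
order = torch.cat(shards)
k_full = torch.cat(local_k)[order.argsort()]
v_full = torch.cat(local_v)[order.argsort()]
local_out = []
for idx, q in zip(shards, local_q):
    mask = (seq_id[idx].unsqueeze(1) == seq_id.unsqueeze(0)) & (
        pos[idx].unsqueeze(1) >= pos.unsqueeze(0)
    )
    scores = (q @ k_full.T) / hidden ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    local_out.append(torch.softmax(scores, dim=-1) @ v_full)

# Leaving the SP region: gather the whole sequence back, because the rest
# of the model (and the sampler) operates on the full, original layout.
out = torch.empty(total, hidden)
for idx, o in zip(shards, local_out):
    out[idx] = o
print(out.shape)                          # torch.Size([8, 16])
```

In the actual PR the gathers are distributed collectives and the attention runs through the SP kernel in python/sglang/srt/layers/radix_attention.py; the sketch only shows why the layout has to switch at the boundaries of the attention block.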
Checklist
- [ ] Ensure pre-commit (`pre-commit run --all-files`) or other linting tools are used to fix potential lint issues.
- [ ] Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
- [ ] Modify documentation as needed, such as docstrings or example tutorials.