
Sequence Parallel

Opened by ZYHowell · 1 comment

Motivation

When serving an extremely large model (e.g. Llama 400B), the number of GPUs can exceed the number of KV heads. This forces the KV cache to be replicated across tensor-parallel ranks (e.g. a model with 8 KV heads served on 16 GPUs stores each KV head twice), which becomes costly when the sequence length is very large.

Modification

This PR introduces a basic form of sequence parallelism for the attention computation. All other parts of the model remain fully tensor parallelized; the partitioning switches before and after the attention (see the sketch after this list). This is achieved by:

  1. When preparing the batch, colocating input ids that belong to the same sequence-parallel rank (sp_rank); this is referred to as the sequence-parallel layout in this PR and in the code comments. The FlashInfer arguments are changed accordingly;
  2. Before entering the SP region, computing only the locally stored KV (python/sglang/srt/layers/linear.py);
  3. The SP attention kernel, which still has room for improvement (python/sglang/srt/layers/radix_attention.py);
  4. When leaving the SP region, gathering the whole sequence again, because the remaining layers operate on the full sequence;
  5. Switching the output logits back to the original layout before sampling;
  6. Miscellaneous modifications, including: parallel state (python/sglang/srt/layers/parallel_utils/parallel_state.py), wiring up all components in the model runner (python/sglang/srt/managers/controller/model_runner.py), the model definition (python/sglang/srt/models/llama2.py), server args (python/sglang/srt/server_args.py), and tests.
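The snippet below is a minimal sketch of the layout switch described in steps 1 to 4, not the code added by this PR. It assumes `torch.distributed` is already initialized, that the token count is divisible by the SP degree, and that `sp_group`, `split_sequence`, `gather_sequence`, and `attention_block` are illustrative names; the KV exchange inside the SP attention kernel itself is not shown.

```python
import torch
import torch.distributed as dist


def split_sequence(hidden: torch.Tensor, sp_rank: int, sp_size: int) -> torch.Tensor:
    """Keep only the tokens owned by this SP rank (the sequence-parallel layout)."""
    # hidden: [num_tokens, hidden_size]; assumes num_tokens % sp_size == 0
    chunks = hidden.chunk(sp_size, dim=0)
    return chunks[sp_rank].contiguous()


def gather_sequence(local: torch.Tensor, sp_group) -> torch.Tensor:
    """Re-assemble the full sequence after attention, since the MLP and the
    rest of the model operate on all tokens."""
    world = dist.get_world_size(sp_group)
    out = [torch.empty_like(local) for _ in range(world)]
    dist.all_gather(out, local, group=sp_group)
    return torch.cat(out, dim=0)


def attention_block(hidden, attn, sp_rank, sp_size, sp_group):
    # 1. Switch to the sequence-parallel layout: each rank keeps a slice of
    #    the tokens, and therefore only a slice of the KV cache.
    local = split_sequence(hidden, sp_rank, sp_size)
    # 2. Run attention over the local tokens; remote KV is handled inside
    #    the SP attention kernel (omitted here).
    local_out = attn(local)
    # 3. Gather the full sequence back for the tensor-parallel layers that follow.
    return gather_sequence(local_out, sp_group)
```

The gather in step 3 is what lets the rest of the model stay fully tensor parallelized: outside the attention block every rank again sees the whole sequence, so no other layer needs to be aware of the sequence-parallel layout.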

Checklist

  • [ ] Ensure pre-commit (`pre-commit run --all-files`) or other linting tools are used to fix potential lint issues.
  • [ ] Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
  • [ ] Modify documentation as needed, such as docstrings or example tutorials.

ZYHowell · Aug 12 '24 02:08