
[Draft][Core] Refactor _prepare_model_input_tensors - take 2

Open · comaniac opened this pull request 1 year ago · 1 comment

This PR refactors _prepare_model_input_tensors. Specifically, we introduce ModelRunnerInputBuilder for logic isolation and modularization. ModelRunnerInputBuilder isolates the following logic:

  1. Managing all processed input data, including token IDs, positions, and sequence lengths, in one place.
  2. Inserting a new sequence group into the input data, taking prefix caching, chunked prefill, sliding windows, etc. into account.
  3. Preparing attention inputs.
  4. Preparing LoRA and multi-modal inputs.
  5. Creating on-device tensors for model inputs.
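
As a rough illustration of this structure, here is a minimal, self-contained sketch of what such a builder could look like. It is not the actual vLLM implementation; class and method names such as InterDataForSeqGroup are assumptions used only to show how per-sequence-group bookkeeping and on-device tensor creation can be kept in one place.

```python
# Minimal sketch (not the actual vLLM code) of a builder that collects
# per-sequence-group data and creates on-device tensors in one final step.
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class InterDataForSeqGroup:
    """Intermediate (host-side) data collected for one sequence group."""
    input_tokens: List[int] = field(default_factory=list)
    input_positions: List[int] = field(default_factory=list)
    seq_lens: List[int] = field(default_factory=list)


class ModelRunnerInputBuilder:
    """Collects per-sequence-group data, then builds on-device tensors."""

    def __init__(self) -> None:
        self.inter_data: List[InterDataForSeqGroup] = []

    def add_seq_group(self, token_ids: List[int], seq_len: int) -> None:
        # Per-sequence-group logic (prefix caching, chunked prefill,
        # sliding window, ...) would be handled here, in one place.
        data = InterDataForSeqGroup()
        data.input_tokens.extend(token_ids)
        data.input_positions.extend(range(seq_len - len(token_ids), seq_len))
        data.seq_lens.append(seq_len)
        self.inter_data.append(data)

    def build(self, device: str = "cpu") -> Dict[str, torch.Tensor]:
        # On-device tensor creation is isolated in a single final step.
        tokens = [t for d in self.inter_data for t in d.input_tokens]
        positions = [p for d in self.inter_data for p in d.input_positions]
        return {
            "input_tokens": torch.tensor(tokens, dtype=torch.long, device=device),
            "input_positions": torch.tensor(positions, dtype=torch.long, device=device),
        }


builder = ModelRunnerInputBuilder()
builder.add_seq_group(token_ids=[11, 12, 13], seq_len=3)  # prefill-like
builder.add_seq_group(token_ids=[42], seq_len=8)          # decode-like
print(builder.build())
```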

Note that the purpose of this PR is to enable follow-up refactoring and optimizations, so we do not expect a noticeable performance improvement at this point.

With this isolation in place, we can pursue follow-up optimizations such as:

  1. Refactor AttentionMetadata to include only on-device tensors, and move all related preparation logic out of ModelRunnerInputBuilder.
  2. Remove the loop for seq_id in seq_ids in ModelRunnerInputBuilder._add_seq_group() by leveraging tensor processing (see the sketch after this list).
  3. Parallelize the loop for seq_group_metadata in seq_group_metadata_list.
  4. and more.
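
As a rough illustration of item 2 above, the sketch below shows the general pattern of replacing a per-element Python loop with batched tensor operations. The function names and data layout are hypothetical and not taken from the vLLM code base.

```python
# Hypothetical illustration: replacing a per-element Python loop with
# batched tensor processing. Names and shapes are assumptions.
import torch


def positions_with_loop(context_lens, query_lens):
    # Baseline: one Python iteration per sequence (and per token).
    positions = []
    for ctx, qlen in zip(context_lens, query_lens):
        positions.extend(range(ctx, ctx + qlen))
    return torch.tensor(positions, dtype=torch.long)


def positions_batched(context_lens, query_lens):
    # Same result with tensor ops: repeat each context length by its query
    # length, then add a per-sequence arange offset.
    ctx = torch.tensor(context_lens, dtype=torch.long)
    qlen = torch.tensor(query_lens, dtype=torch.long)
    starts = torch.repeat_interleave(ctx, qlen)
    offsets = torch.cat([torch.arange(n) for n in query_lens])
    return starts + offsets


# Both variants should produce identical position tensors.
assert torch.equal(positions_with_loop([3, 0], [2, 4]),
                   positions_batched([3, 0], [2, 4]))
```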

comaniac commented on Jul 06 '24

@rkooo567 @zhuohan123 @simon-mo @WoosukKwon @youkaichao @LiuXiaoxuanPKU I've done the first round of refactoring:

  1. The attention-unrelated logic (tokens, sequence lengths, LoRA, multi-modal inputs, etc.) remains in prepare_input.
  2. Prefill and decode logic are kept together.
  3. Attention-specific logic, such as block tables and slot mappings, is moved to the attention metadata builder (see the sketch after this list).
  4. The FlashAttention and FlashInfer metadata builders are self-contained.
  5. The xFormers / ROCmFlashAttention / BlockSparseAttention metadata builders share the same utility functions.
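
To make this division of labor concrete, here is a minimal sketch (not the actual vLLM classes) of per-backend attention metadata builders behind a common interface, sharing a utility function as in items 3–5. Class, method, and helper names are illustrative assumptions.

```python
# Minimal sketch of backend-specific attention metadata builders that
# accumulate block-table / slot-mapping state behind a shared interface.
from abc import ABC, abstractmethod
from typing import Dict, List

import torch


def compute_slot_mapping(block_table: List[int], seq_len: int,
                         block_size: int) -> List[int]:
    """Shared helper, analogous to the utilities reused across several
    backend builders."""
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in range(seq_len)
    ]


class AttentionMetadataBuilder(ABC):
    """Accumulates attention-specific state (block tables, slot mappings)."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.slot_mapping: List[int] = []
        self.seq_lens: List[int] = []

    def add_seq_group(self, block_table: List[int], seq_len: int) -> None:
        self.seq_lens.append(seq_len)
        self.slot_mapping.extend(
            compute_slot_mapping(block_table, seq_len, self.block_size))

    @abstractmethod
    def build(self, device: str) -> Dict[str, torch.Tensor]:
        ...


class FlashAttentionMetadataBuilder(AttentionMetadataBuilder):
    def build(self, device: str = "cpu") -> Dict[str, torch.Tensor]:
        # A backend-specific builder decides the final metadata layout.
        return {
            "slot_mapping": torch.tensor(self.slot_mapping,
                                         dtype=torch.long, device=device),
            "seq_lens": torch.tensor(self.seq_lens,
                                     dtype=torch.int32, device=device),
        }


builder = FlashAttentionMetadataBuilder(block_size=4)
builder.add_seq_group(block_table=[0, 1], seq_len=6)
print(builder.build())
```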

This PR is ready for review. I'll wait for CI to be green first and then rebase to resolve the conflict.

A remaining concern that could potentially be addressed in this PR: the argument list of attn_metadata_builder.add_seq_group() is unwieldy. One reason is that we have to compute the sliding-window sequence length outside of the attention metadata builder (because sequence length is common to all backends), but we also need the original sequence length inside the builder to compute the block table and slot mapping.
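
To make the concern concrete, a hypothetical signature could look like the sketch below: the window-adjusted lengths are computed by the caller, yet the original lengths still have to be passed so the builder can compute block tables and slot mappings. Parameter names are assumptions, not the actual vLLM API.

```python
# Illustrative only: both length variants (plus several flags) must be
# threaded through, which is the source of the argument bloat above.
from typing import Dict, List, Optional


def add_seq_group(
    seq_ids: List[int],
    seq_lens: List[int],           # original lengths: needed inside the
                                   # builder for block tables / slot mapping
    sliding_seq_lens: List[int],   # window-adjusted lengths: computed by the
                                   # caller because they are shared with
                                   # attention-unrelated logic
    block_tables: Optional[Dict[int, List[int]]] = None,
    prefix_cache_hit: bool = False,
) -> None:
    """Hypothetical signature illustrating the argument bloat."""
```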

Follow-up PRs: move more attention-related logic, such as the dummy inputs for CUDA graph capturing and the pre-/post-processing logic in forward.

comaniac commented on Jul 10 '24

Feel free to merge it after addressing the comments!

rkooo567 commented on Jul 16 '24