[Draft][Core] Refactor _prepare_model_input_tensors - take 2
This PR refactors `_prepare_model_input_tensors`. Specifically, it introduces `ModelRunnerInputBuilder`, mainly for logic isolation and modularization: the builder manages all processed input data (token IDs, positions, sequence lengths, etc.) in one place and isolates the following logic (a minimal sketch follows the list below):
1. The logic of inserting a new sequence group into the input data, considering prefix caching, chunked prefill, sliding windows, etc.
2. The logic of preparing attention inputs.
3. The logic of preparing LoRA and multi-modal inputs.
4. The logic of creating on-device tensors for model inputs.
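To make the intended shape concrete, here is a minimal, hypothetical sketch of the builder pattern described above. The field and method names (other than `ModelRunnerInputBuilder` itself) are simplified illustrations, not the actual vLLM signatures:

```python
# Minimal sketch only: real vLLM code has many more fields, arguments, and
# per-backend details than shown here.
from dataclasses import dataclass, field
from typing import Dict, List

import torch


@dataclass
class ModelRunnerInputBuilder:
    """Accumulates per-sequence-group inputs, then materializes tensors."""
    input_tokens: List[int] = field(default_factory=list)
    input_positions: List[int] = field(default_factory=list)
    seq_lens: List[int] = field(default_factory=list)

    def add_seq_group(self, token_ids: List[int], context_len: int) -> None:
        # Isolates the per-sequence-group logic (prefix caching, chunked
        # prefill, sliding windows, ... are omitted in this sketch).
        self.input_tokens.extend(token_ids)
        self.input_positions.extend(
            range(context_len, context_len + len(token_ids)))
        self.seq_lens.append(context_len + len(token_ids))

    def build(self, device: str = "cpu") -> Dict[str, torch.Tensor]:
        # Isolates the creation of on-device tensors for model inputs.
        return {
            "input_tokens": torch.tensor(
                self.input_tokens, dtype=torch.long, device=device),
            "input_positions": torch.tensor(
                self.input_positions, dtype=torch.long, device=device),
            "seq_lens": torch.tensor(
                self.seq_lens, dtype=torch.int32, device=device),
        }
```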
Note that the purpose of this PR is to enable follow-up refactoring and optimizations, so we don't expect an obvious performance improvement at this moment.
With this isolation, we can pursue the following follow-up optimizations:
- Refactor `AttentionMetadata` to only include on-device tensors, and move all related logic from `ModelRunnerInputBuilder`.
- Remove the loop `for seq_id in seq_ids` in `ModelRunnerInputBuilder._add_seq_group()` by leveraging tensor processing (see the sketch after this list).
- Parallelize the loop `for seq_group_metadata in seq_group_metadata_list`.
- and more.
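As a rough illustration of the tensor-processing idea (not the actual vLLM code; all variable names are hypothetical), per-token positions for a batch can be computed with masked tensor ops instead of a Python loop over sequences:

```python
import torch

context_lens = torch.tensor([5, 2, 7])  # tokens already computed per sequence
query_lens = torch.tensor([3, 3, 1])    # new tokens to process per sequence

# Loop-free positions: for sequence i, positions are
# context_lens[i], ..., context_lens[i] + query_lens[i] - 1.
offsets = torch.arange(int(query_lens.max()))
mask = offsets.unsqueeze(0) < query_lens.unsqueeze(1)
positions = (context_lens.unsqueeze(1) + offsets.unsqueeze(0))[mask]
# positions -> tensor([5, 6, 7, 2, 3, 4, 7])
```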
@rkooo567 @zhuohan123 @simon-mo @WoosukKwon @youkaichao @LiuXiaoxuanPKU I've done the first round of refactoring:
- The attention-unrelated logic (tokens, sequence lengths, LoRA, multi-modal inputs, etc.) remains in `prepare_input`.
- Prefill and decode logic are kept together.
- Attention-specific logic, such as block tables and slot mappings, is moved to the attention metadata builder (a hypothetical interface sketch follows this list).
- The FlashAttention and FlashInfer metadata builders are self-contained.
- The xFormers / ROCmFlashAttention / BlockSparseAttention metadata builders share the same utility functions.
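A hypothetical sketch of this per-backend split, with simplified signatures that only mirror the description above (not the actual vLLM API; the real builders carry much more state):

```python
from abc import ABC, abstractmethod
from typing import List

import torch


class AttentionMetadataBuilderBase(ABC):
    """Owns attention-specific state such as block tables and slot mappings."""

    def __init__(self) -> None:
        self.block_tables: List[List[int]] = []

    @abstractmethod
    def add_seq_group(self, seq_len: int, block_table: List[int]) -> None:
        """Record block-table / slot-mapping info for one sequence group."""

    @abstractmethod
    def build(self, device: str) -> torch.Tensor:
        """Materialize backend-specific metadata as on-device tensors."""


class FlashAttentionMetadataBuilder(AttentionMetadataBuilderBase):
    """Self-contained: FlashAttention-specific layout logic lives here."""

    def add_seq_group(self, seq_len: int, block_table: List[int]) -> None:
        self.block_tables.append(block_table)

    def build(self, device: str) -> torch.Tensor:
        # Pad ragged block tables into a rectangular on-device tensor.
        max_len = max((len(bt) for bt in self.block_tables), default=0)
        padded = [bt + [0] * (max_len - len(bt)) for bt in self.block_tables]
        return torch.tensor(padded, dtype=torch.int32, device=device)
```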
This PR is ready for review. I'll wait for CI to be green first and then rebase to resolve the conflict.
Remaining concern that could potentially be addressed in this PR: the arguments of `attn_metadata_builder.add_seq_group()` are ugly. One reason is that we have to compute the sliding-window sequence length outside of the attention metadata (because the sequence length is part of the common input preparation), while we also need the original sequence length to compute the block table and slot mapping inside the attention metadata.
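To illustrate the concern, a hypothetical call site (the parameter names are invented for illustration, not the actual signature):

```python
# Both lengths end up in the argument list: the original seq_len is needed
# inside the builder for block tables / slot mappings, while the
# sliding-window-clamped length is computed outside, in the common path.
from typing import List, Optional


def add_seq_group_to_attn_builder(builder, seq_len: int, context_len: int,
                                  block_table: List[int],
                                  sliding_window: Optional[int]) -> None:
    sliding_seq_len = min(seq_len, sliding_window) if sliding_window else seq_len
    builder.add_seq_group(
        seq_len=seq_len,                  # original length
        sliding_seq_len=sliding_seq_len,  # clamped length for attention
        context_len=context_len,
        block_table=block_table,
    )
```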
Follow-up PRs: move more attention-related logic, such as dummy inputs for CUDA graph capturing and the pre-/post-processing logic in `forward`.
Feel free to merge it after addressing the comments!