[CI/Build] VLM Test Consolidation
Note: this PR is large because it touches nearly all of the VLM tests. The place where they are consolidated (i.e., where new model tests should be added) is here, and the place where each of them eventually lands to run the vLLM/HF runners (for all of the test types) is here.
Overview
Most of the multi-modal image/video tests in vLLM do the same thing, but with small tweaks to things like:
- prompt format / model-specific multimodal tokens
- data types
- number of logprobs being compared
- decorators (e.g., @large_gpu_test, which may control spawning the tests in a new process), and so on.
However, the structure of the tests themselves is very similar. It almost always consists of some boilerplate to configure instances of HfRunner / VllmRunner and compare greedy logprobs in a common _run_test function, which is then wrapped in a test_ function that parametrizes different options, e.g., size_factors.
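For readers unfamiliar with the comparison step, the sketch below shows the idea behind comparing greedy logprobs between the two runners: the generated tokens must match, or, at the first position where they diverge, each runner's token must appear in the other runner's top-k logprobs, after which the comparison stops for that sequence. The greedy_outputs_close helper and its (token ids, per-position top-k dict) output format are hypothetical simplifications, not vLLM's actual test utility.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical output format: (generated token ids, per-position top-k logprobs),
# where each top-k entry maps a candidate token id to its logprob.
Output = Tuple[List[int], List[Dict[int, float]]]


def greedy_outputs_close(hf_outputs: Sequence[Output],
                         vllm_outputs: Sequence[Output]) -> bool:
    """Return True if every HF/vLLM output pair passes the greedy check."""
    if len(hf_outputs) != len(vllm_outputs):
        raise ValueError("runners produced a different number of outputs")
    for (hf_ids, hf_logprobs), (vl_ids, vl_logprobs) in zip(
            hf_outputs, vllm_outputs):
        for pos, (hf_tok, vl_tok) in enumerate(zip(hf_ids, vl_ids)):
            if hf_tok == vl_tok:
                continue
            # First divergence: each runner's choice must be in the other's top-k.
            if hf_tok not in vl_logprobs[pos] or vl_tok not in hf_logprobs[pos]:
                return False
            # After the sequences diverge, later tokens are not comparable.
            break
    return True
```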
This PR aims to consolidate much of the redundancy in vLLM's multimodal tests, starting with the visual tests, but in a way that can (hopefully) be easily extended to other modalities, e.g., audio.
This is accomplished by defining five test types, which together cover most of our visual model tests:
- single image
- multiple images
- visual embeddings (i.e., vLLM gets image embeddings, HF model gets images)
- video
- custom inputs, i.e., some other prepared input for an edge case of a specific model
We then define a test wrapper for each of them. Each wrapper leverages common utilities and invokes a common run_test, which is written to be model-agnostic. As a result, to add a test for a new type of model, you should (mostly) only need to add a new object configuring that model to VLM_TEST_SETTINGS, and only worry about things that are model-specific, e.g., post-processing on the vLLM runner output.
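As a rough sketch of how the five test types and the per-type wrappers could relate, the code below models each type as an enum member and lets an entry declare either a single type or a tuple of types (both forms appear in the GLM-4 and Phi3v examples). The member names other than IMAGE and MULTI_IMAGE, the normalize_test_types and entries_for helpers, and the name-to-type mapping are assumptions for illustration, not vLLM's actual definitions.

```python
from enum import Enum, auto
from typing import Dict, Iterable, List, Tuple, Union


class VLMTestType(Enum):
    """Simplified sketch of the five test types (names partly assumed)."""
    IMAGE = auto()
    MULTI_IMAGE = auto()
    EMBEDDING = auto()
    VIDEO = auto()
    CUSTOM_INPUTS = auto()


TestTypes = Union[VLMTestType, Iterable[VLMTestType]]


def normalize_test_types(test_type: TestTypes) -> Tuple[VLMTestType, ...]:
    """Accept one test type or a collection of them; always return a tuple."""
    if isinstance(test_type, VLMTestType):
        return (test_type, )
    return tuple(test_type)


def entries_for(settings: Dict[str, TestTypes],
                wanted: VLMTestType) -> List[str]:
    """Names of the entries that the wrapper for `wanted` would pick up."""
    return [
        name for name, test_types in settings.items()
        if wanted in normalize_test_types(test_types)
    ]
```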
Example 1: GLM-4
Here is an example of an object describing the GLM-4 tests:

```python
"glm4": VLMTestInfo(
    models="THUDM/glm-4v-9b",
    test_type=VLMTestType.IMAGE,
    prompt_formatter=identity,
    fork_new_process_for_each_test=True,
    img_idx_to_prompt=lambda idx: "",
    max_model_len=2048,
    max_num_seqs=2,
    dtype="bfloat16",
    get_stop_token_ids=lambda tok: [151329, 151336, 151338],
    skip=(get_memory_gb() < 48),  # Large GPU test
    patch_hf_runner=vlm_utils.glm_patch_hf_runner,
),
```
This indicates that we should run a single-image test, using the default single-image prompt / image token, without additional prompt formatting. This will actually result in four tests, because each of the default size factors runs as a separate pytest case. Additionally, each test case runs in a separate process and is skipped when get_memory_gb() < 48, similar to writing the test with @large_gpu_test(min_gb=48), to ensure resources are still cleaned up correctly.
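To make the "four tests" concrete, here is a small sketch assuming the default size factors are the four cases that the older per-model VLM tests parametrized over (no image, single scale, single scale batched, and multi-scale), with each factor scaling the original image size. Both the DEFAULT_SIZE_FACTORS values and the scaled_sizes helper are illustrative assumptions, not copied from the new utilities.

```python
from typing import List, Tuple

# Assumed defaults, mirroring the cases older per-model VLM tests used:
# no image, single-scale, single-scale batched, and multi-scale.
DEFAULT_SIZE_FACTORS: List[List[float]] = [
    [],
    [1.0],
    [1.0, 1.0, 1.0],
    [0.25, 0.5, 1.0],
]


def scaled_sizes(original: Tuple[int, int],
                 size_factors: List[float]) -> List[Tuple[int, int]]:
    """Hypothetical helper: the (width, height) each image is resized to."""
    width, height = original
    return [(int(width * f), int(height * f)) for f in size_factors]


# One pytest case per entry, so a single IMAGE entry expands to four cases.
cases = [scaled_sizes((640, 480), factors) for factors in DEFAULT_SIZE_FACTORS]
```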
Example 2: Phi3v
Phi3v is a model that supports multiple images, so it has both single- and multi-image tests. This is a good example of what our multi-image tests typically look like, but there are a few quirks specific to the model, e.g., needing to use eager attention when running the HF model. The entry below configures both the single- and multi-image tests for the default size factors.

```python
"phi3v": VLMTestInfo(
    models="microsoft/Phi-3.5-vision-instruct",
    test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
    prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n", # noqa: E501
    img_idx_to_prompt=lambda idx: f"<|image_{idx}|>\n",
    max_model_len=4096,
    max_num_seqs=2,
    # use eager mode for hf runner, since phi3v didn't work with flash_attn
    model_kwargs={"_attn_implementation": "eager"},
    use_tokenizer_eos=True,
    vllm_output_post_proc=vlm_utils.phi3v_vllm_to_hf_output,
    num_logprobs=10,
),
```
As a result of the default size factors, when this entry gets picked up it will run four single-image tests and four multi-image tests (one per size factor for each test type). Settings that are common to all test types, such as max_model_len and max_num_seqs, are used for every test type.
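The two prompt hooks in the Phi3v entry could compose roughly like this for a multi-image prompt: img_idx_to_prompt yields one placeholder per image, the placeholders are prepended to the question, and prompt_formatter wraps the result in the chat template. The build_prompt helper, the 1-based image indices, and the question text are assumptions for illustration; the actual wrappers may assemble prompts differently.

```python
def prompt_formatter(img_prompt: str) -> str:
    """Chat template from the Phi3v entry's prompt_formatter."""
    return f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n"


def img_idx_to_prompt(idx: int) -> str:
    """Per-image placeholder from the Phi3v entry's img_idx_to_prompt."""
    return f"<|image_{idx}|>\n"


def build_prompt(num_images: int, question: str) -> str:
    """Hypothetical helper: one placeholder per image (assumed 1-based),
    followed by the question, all wrapped in the chat template."""
    placeholders = "".join(
        img_idx_to_prompt(idx) for idx in range(1, num_images + 1))
    return prompt_formatter(placeholders + question)


prompt = build_prompt(2, "What is the difference between these images?")
```

Because the model-specific pieces are plain callables, the shared wrapper never needs to know Phi3v's template; it only decides how many images a test case has.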
Other implementation details
- Each VLMTestInfo has the flexibility to set the things we would normally parametrize, e.g., size_factors. This is accomplished by consuming all of the VLM_TEST_SETTINGS, dropping entries that are marked as skipped or that don't have the corresponding test_type, and then taking an itertools product over specific fields (e.g., models, num_tokens, size_factors/fixed_sizes, dtype, etc.) for each matching entry, so that each combination runs as its own test.
- Each test type has two implementations: the normal one, and a heavy version that is identical but runs each test in a new process. If a test sets fork_new_process_for_each_test=True, it runs in the heavy version, which is wrapped in the @fork_new_process_for_each_test decorator.
  - This can be paired with the skip field to achieve the same effect as some of our existing decorators for cleaning up resources. For example, setting VLMTestInfo(..., fork_new_process_for_each_test=True, skip=(get_memory_gb() < 48)) will behave like @large_gpu_test(min_gb=48).
- A lot of the current tests use very slight variations of the same prompts. This PR aligns most of them on common prompts, so the scores will change slightly. The settings for most models should (hopefully) be the same, though; while porting each model, I ensured that when the prompts matched, the vLLM runner and HF runner produced the same output.
- In some cases, e.g., PaliGemma, the prompt is quite different, and using the default one fails the logprobs check. This is why the prompt is set directly for some models.
- There are a couple of other tests that didn't really fit (e.g., the quantized InternVL test and the error case for Llava). For now, I left these where they were, since I thought it would be a good idea to first figure out whether we prefer this approach or something in between (e.g., keeping each model's tests in its own file, but using the common test runner here).
- To handle modalities like audio, I think it should be pretty similar to handling images vs. video; if we choose to make this sort of change, hopefully we can just add a new test_type to pass audio through the common wrapper for the vLLM/HF runners.
- For things like quantized tests, thoughts on what would feel nice are welcome, but I imagine we could probably just pass the model / quantized model and, if it's set, create two vLLM runners instead of an HF/vLLM runner.
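The expansion described in the first two bullets can be sketched as follows: drop skipped entries and entries without the requested test type, take an itertools.product over the per-entry fields, and route each combination to either the normal or the heavy (forked) wrapper. The ModelEntry dataclass, its default dtype and size factors, and the expand helper are simplified, hypothetical stand-ins for VLMTestInfo and the real collection logic; only models, dtypes and size_factors are expanded here.

```python
import itertools
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Case = Tuple[str, str, str, Tuple[float, ...]]


@dataclass
class ModelEntry:
    """Hypothetical, trimmed-down stand-in for VLMTestInfo."""
    models: List[str]
    test_types: Tuple[str, ...]
    dtypes: List[str] = field(default_factory=lambda: ["half"])
    size_factors: List[Tuple[float, ...]] = field(
        default_factory=lambda: [(), (1.0, ), (1.0, 1.0, 1.0),
                                 (0.25, 0.5, 1.0)])
    fork_new_process_for_each_test: bool = False
    skip: bool = False


def expand(settings: Dict[str, ModelEntry], test_type: str,
           heavy: bool) -> List[Case]:
    """Return (name, model, dtype, size_factors) cases for one wrapper.

    heavy=True collects the entries that must run in a forked process."""
    cases: List[Case] = []
    for name, entry in settings.items():
        # Drop skipped entries and entries that don't declare this test type.
        if entry.skip or test_type not in entry.test_types:
            continue
        # Route each entry to exactly one of the normal / heavy wrappers.
        if entry.fork_new_process_for_each_test != heavy:
            continue
        for model, dtype, factors in itertools.product(
                entry.models, entry.dtypes, entry.size_factors):
            cases.append((name, model, dtype, factors))
    return cases


settings = {
    "glm4": ModelEntry(models=["THUDM/glm-4v-9b"], test_types=("image", ),
                       dtypes=["bfloat16"],
                       fork_new_process_for_each_test=True),
    "phi3v": ModelEntry(models=["microsoft/Phi-3.5-vision-instruct"],
                        test_types=("image", "multi_image")),
}
normal_image_cases = expand(settings, "image", heavy=False)
heavy_image_cases = expand(settings, "image", heavy=True)
```

In a test module, lists like normal_image_cases could then be passed to pytest.mark.parametrize so that each combination shows up as its own pytest case, with the heavy list feeding a test function that forks per test.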