vllm
[Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi
Fixes https://github.com/vllm-project/vllm/issues/4171
Allow non-power-of-two head sizes in prefix prefill with ALiBi. This is a small fix based on https://github.com/vllm-project/vllm/pull/4128.
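Background for why head sizes matter here: Triton's `tl.arange` only accepts power-of-two ranges, so a kernel that supports arbitrary head sizes typically pads the head dimension up to the next power of two and masks out the padded lanes. The sketch below illustrates that idea in plain Python. `next_power_of_2` and `head_dim_mask` are hypothetical helpers for illustration (Triton itself provides `triton.next_power_of_2`), not the exact code in this PR.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def head_dim_mask(head_size: int) -> list[bool]:
    """Mask over the padded head dimension: True for real lanes, False for padding."""
    padded = next_power_of_2(head_size)
    return [d < head_size for d in range(padded)]


# Example: a head size of 80 is padded to 128; lanes 80..127 are masked out.
mask = head_dim_mask(80)
print(len(mask), sum(mask))  # 128 80
```

Padding keeps every `tl.arange` legal, while the mask ensures the extra lanes are never read from or written to memory.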
LGTM FWIW!
WTAL soon! Can you update the test?
OK, I will add a test for the prefix prefill kernel with ALiBi later.
The local tests passed. Waiting for all CI kernel tests to pass.
pytest test_prefix_prefill.py
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.13, pytest-8.1.0, pluggy-1.4.0
configfile: pyproject.toml
plugins: shard-0.1.2, rerunfailures-13.0, anyio-4.3.0, asyncio-0.23.5, forked-1.6.0
asyncio: mode=strict
collected 48 items
Running 48 items in this shard
test_prefix_prefill.py ................................................ [100%]
============================================================================= 48 passed in 157.21s (0:02:37) =============================================================================
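A kernel test like this usually compares the Triton output against a straightforward reference. The sketch below is a hypothetical, naive single-head reference in plain Python (the name `alibi_prefix_attention` and its parameters are illustrative, not the helper used in `test_prefix_prefill.py`). Each new query token sits after `ctx_len` cached tokens, attends causally to all cached tokens and to new tokens up to itself, and gets an ALiBi bias of `slope * (key_pos - query_pos)` added to its scaled scores.

```python
import math


def alibi_prefix_attention(q, k, v, ctx_len, slope, scale):
    """Naive single-head prefix-prefill attention with an ALiBi bias.

    q: [num_new][D] query vectors for the new tokens.
    k, v: [ctx_len + num_new][D] keys/values for cached + new tokens.
    Query i sits at absolute position ctx_len + i and may attend to keys 0..ctx_len + i.
    """
    out = []
    for i, qi in enumerate(q):
        q_pos = ctx_len + i
        scores = []
        for k_pos in range(q_pos + 1):  # causal: skip keys after the query
            dot = sum(a * b for a, b in zip(qi, k[k_pos]))
            # ALiBi: linear penalty on distance; slope * (k_pos - q_pos) <= 0
            scores.append(dot * scale + slope * (k_pos - q_pos))
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        dim = len(v[0])
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / total
                    for d in range(dim)])
    return out


# Head size 6 (not a power of two), 2 cached tokens, 1 new token, no bias:
# zero queries give uniform weights, so the output is the mean of values 0, 1, 2.
print(alibi_prefix_attention([[0.0] * 6], [[1.0] * 6] * 3,
                             [[float(j)] * 6 for j in range(3)],
                             ctx_len=2, slope=0.0, scale=1.0))
```

A test could then run the kernel over non-power-of-two head sizes and check that its output matches such a reference within a floating-point tolerance.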
Hi @rkooo567 @simon-mo, can you take a look at this PR? All CI tests have passed. I need this fix for the BLOOM model. Many thanks!
Sorry, I am going to take a look at it today!
@rkooo567 The ALiBi tests have been moved to test_prefix_prefill.py, and the kernel tests in CI have passed.
So the logic change itself is equivalent to the non-ALiBi case, right?
Also, @WoosukKwon, it'd be great if you could take a look at the ALiBi slope testing. It looks okay to me, but just in case...
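For reference when reviewing slope tests: ALiBi assigns each attention head a fixed slope from a geometric sequence, with an interleaving rule when the number of heads (not the head size this PR fixes) is not a power of two. The sketch below follows the common BLOOM-style formulation; the name `get_alibi_slopes` is illustrative, and this is not necessarily the exact helper the PR's test uses.

```python
import math


def get_alibi_slopes(num_heads: int) -> list[float]:
    """Per-head ALiBi slopes; handles head counts that are not powers of two."""
    closest = 2 ** math.floor(math.log2(num_heads))
    # For a power-of-two count n, slopes are base**1 .. base**n with base = 2**(-8/n).
    base = 2 ** (-(2 ** -(math.log2(closest) - 3)))
    slopes = [base ** p for p in range(1, closest + 1)]
    if closest != num_heads:
        # Remaining heads take every other slope from the 2*closest sequence.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest) - 3)))
        num_remaining = min(closest, num_heads - closest)
        slopes += [extra_base ** p for p in range(1, 1 + 2 * num_remaining, 2)]
    return slopes


print(get_alibi_slopes(8))  # 1/2, 1/4, ..., 1/256
```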
Yes, the logic is the same as in the non-ALiBi case. The additional dim_mask, combined with offs_m, produces the correct mask for tl.load. For example:
# dim_mask[None, :] & (offs_m[:, None] < bound): [1, D] & [M, 1] broadcasts to [M, D]
# >>> dim_mask = torch.tensor([1,1,1,1,0,0])  # [D] = [6]; the last two lanes are padding
# >>> offs_m = torch.tensor([8,9,10,11])  # [M] = [4]
# >>> mask = dim_mask[None, :] & (offs_m[:, None] < 12)
# >>> mask
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0]])
# >>> mask.shape
# torch.Size([4, 6])  # [M, D]
# Load the [M, D] query tile. Lanes beyond the real head size, or rows past the
# new (non-context) tokens of this sequence, are not read and are filled with 0.0.
q = tl.load(Q + off_q,  # [M, D]
            mask=dim_mask[None, :] &
            (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
            other=0.0)
I can add some comments if needed.
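To make the masking semantics concrete without a GPU, the sketch below emulates that masked load in plain Python: an element is read only when its row offset is in range and its head-dimension index is below the real head size; every other element takes `other`. `masked_load` and its parameters are hypothetical names for illustration; Triton's real `tl.load` operates on blocks of pointers, not Python lists.

```python
def masked_load(src, offs_m, bound, head_size, block_d, other=0.0):
    """Emulate tl.load(ptrs, mask=dim_mask[None, :] & (offs_m[:, None] < bound), other=other).

    src: rows of length head_size (the real, unpadded data).
    offs_m: row offsets for this tile ([M]).
    block_d: padded head dimension ([D], a power of two >= head_size).
    """
    dim_mask = [d < head_size for d in range(block_d)]  # [D]
    tile = []
    for m in offs_m:
        row_ok = m < bound  # [M, 1] part of the mask
        # Short-circuiting means masked-out rows/lanes are never indexed,
        # mirroring how the mask prevents out-of-bounds memory reads.
        tile.append([src[m][d] if (row_ok and dim_mask[d]) else other
                     for d in range(block_d)])
    return tile  # [M, D]


# 10 rows of head size 3, padded to 4; rows 10 and 11 are past the bound.
src = [[float(10 * r + d) for d in range(3)] for r in range(10)]
for row in masked_load(src, [8, 9, 10, 11], bound=10, head_size=3, block_d=4):
    print(row)
```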
@rkooo567 It seems many tests were interrupted by a signal.
Can you try merging the latest master? I've seen this happen occasionally... not sure what the root cause is.