[core] try to remove seq group from core
Try to hide the sequence group from the core by handling parallel sampling in the LLM engine.
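A minimal sketch of the idea, assuming a hypothetical simplified `Request` type (not vLLM's real class): the engine front end splits one request with n > 1 into n independent n=1 child requests and collects their outputs back into one parent result, so the scheduler only ever sees single sequences. The child-id format `"{request_id}_{i}"` and the `seed + i` per-child seed derivation are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:
    """Hypothetical, simplified request type (not vLLM's real class)."""
    request_id: str
    prompt: str
    n: int = 1
    seed: Optional[int] = None


def fan_out(parent: Request) -> List[Request]:
    """Split a request with n > 1 into n single-sequence child requests."""
    children = []
    for i in range(parent.n):
        # Hypothetical: derive a distinct seed per child so seeded samples differ.
        seed = None if parent.seed is None else parent.seed + i
        children.append(Request(f"{parent.request_id}_{i}", parent.prompt, 1, seed))
    return children


class ParentAggregator:
    """Collects finished child outputs back into one n-sample result."""

    def __init__(self, children: List[Request]):
        # Map each child id to its sample index in the parent's output list.
        self.pending = {c.request_id: i for i, c in enumerate(children)}
        self.outputs: List[Optional[str]] = [None] * len(children)

    def add(self, child_id: str, text: str) -> bool:
        """Record one finished child; return True once every child has finished."""
        index = self.pending.pop(child_id)
        self.outputs[index] = text
        return not self.pending
```

Children may finish in any order; the aggregator keeps each output at its original sample index, so the parent result is ordered even though the scheduler treated the children as unrelated requests.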
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add the ready label to the PR
- Enable auto-merge.

🚀
~~The caveat is that we cannot support streaming when n > 1. I think people don't use streaming when n > 1, and its behavior is not clearly defined.~~
~~Say we have n = 5, the first stream step yields 5 tokens, and then sequence 2 finishes: do we send 5 outputs with the 2nd one empty, or send 4 outputs and let users maintain the state?~~
The OpenAI API behavior is as follows:
every sequence in parallel sampling is assigned a unique index, and the stream is flattened so that each chunk carries one token of one sequence. A chunk does not need to contain n tokens at a time.
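The flattening described above can be sketched as follows. This is an illustrative model, not vLLM's or OpenAI's server code: each of the n samples is assumed to be a plain list of token strings, chunks are modeled as `(index, content, finish_reason)` tuples mirroring the fields of a streamed choice, and samples are visited round-robin. A sample that runs out of tokens gets its own `"stop"` chunk immediately, which answers the struck-through question: the stream simply stops carrying that index.

```python
from typing import Iterator, List, Optional, Tuple

# (index, content, finish_reason), mirroring the fields of a streamed choice.
Chunk = Tuple[int, Optional[str], Optional[str]]


def flatten(samples: List[List[str]]) -> Iterator[Chunk]:
    """Interleave n per-sample token lists into one flat stream, one token per chunk."""
    active = set(range(len(samples)))
    step = 0
    while active:
        for i in sorted(active):  # sorted() copies, so discarding below is safe
            tokens = samples[i]
            if step < len(tokens):
                yield (i, tokens[step], None)
            if step + 1 >= len(tokens):
                # Sample i is finished: close it now instead of waiting for the others.
                yield (i, None, "stop")
                active.discard(i)
        step += 1
```

For example, `list(flatten([["Apple", "."], ["Apple", ".", " Apple"]]))` emits sample 0's stop chunk before sample 1 has finished, so no chunk ever needs a placeholder for a completed sample.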
This is the test script:

```python
from openai import OpenAI

api_key = ''  # fill in your OpenAI API key

client = OpenAI(
    api_key=api_key,
)

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Repeat after me: apple."}],
    stream=True,
    max_tokens=5,
    n=1,
)
for chunk in stream:
    print(chunk)
```
And the output:

```text
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None, refusal=None), finish_reason=None, index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content='Apple', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content='.', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDmVXderuQ5vwRKIqOCT2vJsEydq', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1729144571, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
```
When I use n=2:

```text
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None, refusal=None), finish_reason=None, index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='Apple', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='', function_call=None, role='assistant', tool_calls=None, refusal=None), finish_reason=None, index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='Apple', function_call=None, role=None, tool_calls=None), finish_reason=None, index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='.', function_call=None, role=None, tool_calls=None), finish_reason=None, index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content='.', function_call=None, role=None, tool_calls=None), finish_reason=None, index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
ChatCompletionChunk(id='chatcmpl-AJDo9qfpbDuIFeHHjJtkkMxvVhs2P', choices=[Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=None), finish_reason='stop', index=1, logprobs=None)], created=1729144673, model='gpt-4o-mini-2024-07-18', object='chat.completion.chunk', system_fingerprint='fp_e2bde53e6e')
```
I first get two chunks from sequence 0, then two from sequence 1; after that, chunks from the two sequences interleave, each tagged with its index.
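On the client side, a flattened stream like this can be regrouped by `index`. The sketch below is illustrative, not part of any SDK: it takes plain `(index, content, finish_reason)` tuples, and the `n2_stream` list is the n=2 output above reduced to those three fields.

```python
from collections import defaultdict
from typing import Dict, Iterable, Optional, Tuple


def reassemble(
    chunks: Iterable[Tuple[int, Optional[str], Optional[str]]],
) -> Dict[int, Tuple[str, Optional[str]]]:
    """Group flattened (index, content, finish_reason) chunks back into n completions."""
    texts: Dict[int, str] = defaultdict(str)
    finish: Dict[int, Optional[str]] = {}
    for index, content, finish_reason in chunks:
        if content:  # skips the empty role-only chunk and the None on the stop chunk
            texts[index] += content
        if finish_reason is not None:
            finish[index] = finish_reason
    indices = sorted(set(texts) | set(finish))
    return {i: (texts[i], finish.get(i)) for i in indices}


# The n=2 stream above, reduced to (index, content, finish_reason).
n2_stream = [
    (0, "", None), (0, "Apple", None),
    (1, "", None), (1, "Apple", None),
    (0, ".", None), (1, ".", None),
    (0, None, "stop"), (1, None, "stop"),
]
print(reassemble(n2_stream))  # {0: ('Apple.', 'stop'), 1: ('Apple.', 'stop')}
```

With the OpenAI Python SDK, the tuples can be produced as `((c.index, c.delta.content, c.finish_reason) for chunk in stream for c in chunk.choices)`, using the `index`, `delta.content`, and `finish_reason` fields visible in the printed chunks.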
@robertgshaw2-neuralmagic can you help take a look? I ran into a strange error:
`pytest -v -s tests/entrypoints/openai/test_completion.py::test_guided_json_completion[-outlines]`
fails with this implementation. lm-format-enforcer works well, and `--disable-frontend-multiprocessing` also works; only the combination of MQLLMEngine and outlines fails.
Surprisingly, CI actually passes ... it only errors on my local dev machine.
@youkaichao how would this PR impact best_of > 1 requests? Is best_of functionality still within the engine, or is it moved outside the engine as has been done for beam search? @robertgshaw2-neuralmagic @njhill
@afeldman-nm best_of > 1 is already converted to parallel sampling in https://github.com/vllm-project/vllm/pull/9261/
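Converting best_of to parallel sampling roughly means: request `best_of` parallel samples, then keep the n best. The sketch below assumes each sample carries a cumulative log probability and that "best" means the highest cumulative log probability; that ranking criterion is an assumption here, not a statement of what the linked PR does.

```python
from typing import List, Tuple


def best_of_to_n(candidates: List[Tuple[str, float]], n: int) -> List[Tuple[str, float]]:
    """Keep the n best of `best_of` parallel samples.

    candidates: (text, cumulative_logprob) pairs, one per parallel sample.
    Assumption: "best" means highest cumulative log probability.
    """
    if not 1 <= n <= len(candidates):
        raise ValueError("need 1 <= n <= best_of")
    # Sort descending by cumulative logprob and keep the top n.
    return sorted(candidates, key=lambda c: c[1], reverse=True)[:n]
```

Because the selection happens after all samples finish, the scheduler never needs to know the samples are related; it only runs best_of independent sequences.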
Could you explain the benefit of doing so? It seems that with this change, the scheduler can no longer make decisions based on the number of sequences within a SequenceGroup.
Yes, the scheduler will only process single sequences in the future, to keep the core code simple.
This modification makes the "fork" mechanism of vLLM completely unused. Previously, for a request with n > 1, its prompt was prefilled only once, and then the sequence was "forked" into n sequences to avoid redundant computation. After this modification, a request with n > 1 has to prefill its prompt n times. A small experiment can verify this; notice how much slower it has become after this squashed commit.

```python
import time

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Once upon a time, there was a king.",
]
# Create a sampling params object.
sampling_params = SamplingParams(seed=42, temperature=0.1, max_tokens=1, n=100)
# Create an LLM.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
# Warm up.
outputs = llm.generate(prompts, sampling_params)
# Time a second, identical call.
begin_time = time.time()
outputs = llm.generate(prompts, sampling_params)
end_time = time.time()
print(f"{end_time - begin_time}s")
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
Yes, this is intended. Please use prefix caching to speed up and share the prefill. None of the sharing will be hardcoded in the scheduler; it will only happen through prefix caching.
I'm not sure whether prefix caching currently supports sharing within the same batch. If you want optimal performance, I would suggest running an n=1 request first, and then another request with n - 1 samples.
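To make the trade-off in this exchange concrete, here is a rough cost model that only counts prompt tokens prefilled for one request with n samples. It is a sketch under stated assumptions, not vLLM's accounting: block size 16, cache hits only on full blocks, at least one prompt token always recomputed to produce logits, and the n - 1 later requests scheduled after the first one has filled the cache (which is what the n=1-first workaround aims to guarantee). Decode cost and scheduling overhead are ignored.

```python
def prefill_tokens(prompt_len: int, n: int, mode: str, block_size: int = 16) -> int:
    """Rough count of prompt tokens prefilled for one request with n samples."""
    if mode == "fork":
        # Prompt prefilled once, then the sequence is forked into n samples.
        return prompt_len
    if mode == "independent":
        # n independent single-sequence requests, no sharing at all.
        return n * prompt_len
    if mode == "prefix_cache":
        # Assumptions: the first request fills the cache and the other n - 1 hit it;
        # only full blocks are reusable, and at least one token is always recomputed.
        cached = (prompt_len // block_size) * block_size
        per_child = max(1, prompt_len - cached)
        return prompt_len + (n - 1) * per_child
    raise ValueError(f"unknown mode: {mode}")
```

Under these assumptions, a 100-token prompt with n=100 costs 100 prefill tokens with fork, 10,000 with fully independent requests, and 496 with prefix caching. A prompt shorter than one block has no full block to cache, so the model predicts no saving from prefix caching in that case.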
Thank you for clarifying. Prefix caching does support sharing within the same batch, though the performance gain is smaller than with the "fork" mechanism.
This PR reduces our vLLM throughput for n > 1 by about 3x, making this and later versions completely unusable. Enabling prefix caching does not make any difference in our performance tests.