# [Bug]: Multistep with n>1 Fails
### Your current environment

The output of `python collect_env.py`:

```text
Your output of `python collect_env.py` here
```

### 🐛 Describe the bug
Launched server with:

```bash
vllm serve $MODEL --num-scheduler-steps 8
```
Sent the following request:

```python
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="A robot may not injure a human being",
    echo=False,
    n=2,
    stream=stream)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)
```
Got the following output:

```text
INFO: Finished server process [1668044]
INFO 08-28 19:29:45 server.py:222] vLLM ZMQ RPC Server was interrupted.
Future exception was never retrieved
future: <Future finished exception=RuntimeError('shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]')>
Traceback (most recent call last):
File "/home/rshaw/vllm/vllm/entrypoints/openai/rpc/server.py", line 111, in generate
async for request_output in results_generator:
File "/home/rshaw/vllm/vllm/engine/async_llm_engine.py", line 1050, in generate
async for output in await self.add_request(
File "/home/rshaw/vllm/vllm/engine/async_llm_engine.py", line 110, in generator
raise result
File "/home/rshaw/vllm/vllm/engine/async_llm_engine.py", line 52, in _log_task_completion
return_value = task.result()
^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/engine/async_llm_engine.py", line 916, in run_engine_loop
result = task.result()
^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/engine/async_llm_engine.py", line 859, in engine_step
request_outputs = await self.engine.step_async(virtual_engine)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/engine/async_llm_engine.py", line 346, in step_async
output = await self.model_executor.execute_model_async(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/executor/gpu_executor.py", line 178, in execute_model_async
output = await make_async(self.driver_worker.execute_model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/.pyenv/versions/3.11.9/lib/python3.11/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/worker/worker_base.py", line 327, in execute_model
output = self.model_runner.execute_model(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/worker/multi_step_model_runner.py", line 275, in execute_model
output = self._base_model_runner.execute_model(frozen_model_input,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/worker/model_runner.py", line 1489, in execute_model
output: SamplerOutput = self.model.sample(
^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/model_executor/models/llama.py", line 447, in sample
next_tokens = self.sampler(logits, sampling_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/model_executor/layers/sampler.py", line 153, in forward
sample_results, maybe_sampled_tokens_tensor = _sample(
^^^^^^^^
File "/home/rshaw/vllm/vllm/model_executor/layers/sampler.py", line 771, in _sample
return _sample_with_torch(
^^^^^^^^^^^^^^^^^^^
File "/home/rshaw/vllm/vllm/model_executor/layers/sampler.py", line 633, in _sample_with_torch
sampled_token_ids_tensor[long_sample_indices] = \
~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]
```
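Reading the error: with `n=2`, the sampler returns two token ids for the single running prompt, but the indexed slice of `sampled_token_ids_tensor` that the multi-step path writes into only has room for one, hence the `[2]` vs `[1, 1]` broadcast failure. A standalone torch sketch of the same indexing failure (illustrative only, not vLLM's actual allocation code):

```python
import torch

# One pre-allocated output slot: 1 sequence x 1 token per step,
# standing in for sampled_token_ids_tensor in the traceback above.
sampled_token_ids = torch.zeros(1, 1, dtype=torch.long)

# With n=2 the sampler produces two token ids for that one prompt...
new_tokens = torch.tensor([42, 1337])

# ...and writing them into the single slot raises:
# RuntimeError: shape mismatch: value tensor of shape [2] cannot be
# broadcast to indexing result of shape [1, 1]
sampled_token_ids[torch.tensor([0])] = new_tokens
```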
### Before submitting a new issue...
- [X] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
I will take a look later today
Looks like @tdoublep encountered this issue a while ago in the context of speculative decoding and has a PR with a fix (that would need to be rebased):
- Issue: https://github.com/vllm-project/vllm/issues/6137
- PR: https://github.com/vllm-project/vllm/pull/6138
I also found a couple other issues for the same crash:
- https://github.com/vllm-project/vllm/issues/8261
- https://github.com/vllm-project/vllm/issues/4934
cc @afeldman-nm
I'm running into the same issue. Does anyone know of a workaround? We don't need `best_of` or `use_beam_search`.
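One client-side workaround, sketched below under the assumption that you only need `n` independent samples (no `best_of` pruning or beam search): keep the server on multi-step scheduling and fan the request out as `n` separate `n=1` completions, then merge the choices yourself.

```python
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

def sample_n(prompt: str, n: int) -> list[str]:
    """Emulate n>1 by sending n independent n=1 completion requests."""
    completions = [
        client.completions.create(
            model=model,
            prompt=prompt,
            n=1,              # keeps the multi-step sampler on the working path
            max_tokens=64,
            temperature=0.8,  # > 0 so the samples actually differ
        )
        for _ in range(n)
    ]
    return [c.choices[0].text for c in completions]

print(sample_n("A robot may not injure a human being", n=2))
```

The obvious cost is that the prompt is prefilled once per sample instead of being shared across the `n` sequences, so this is a stopgap rather than a substitute for real `n>1` support.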
We can reproduce using vLLM's provided `benchmark_throughput.py`.

This runs OK:

```bash
python benchmarks/benchmark_throughput.py --input-len=768 --output-len=256 --model=codellama/CodeLlama-7b-hf --max-model-len=1024 --num-prompts=1 --num-scheduler-steps=2 --n=1
```

This crashes:

```bash
python benchmarks/benchmark_throughput.py --input-len=768 --output-len=256 --model=codellama/CodeLlama-7b-hf --max-model-len=1024 --num-prompts=1 --num-scheduler-steps=2 --n=2
```

The error I'm getting is:

```text
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1633, in execute_model
[rank0]: output: SamplerOutput = self.model.sample(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 466, in sample
[rank0]: next_tokens = self.sampler(logits, sampling_metadata)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 274, in forward
[rank0]: maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 879, in _sample
[rank0]: return _sample_with_torch(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 826, in _sample_with_torch
[rank0]: sampled_token_ids_tensor[long_sample_indices] = \
[rank0]: RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]
```
@comaniac Hi, just wondering if someone working on vLLM can provide an update on this. We want to use the multi-step scheduler because the throughput is much better for our needs, but we also need to set n > 1. Simply disabling multi-step in that case won't work for us. Thanks!
Sorry, we're busy with the company event (Ray Summit) until this week. I will try to find some time after the event to look into it. @SolitaryThinker could you also take a look if you get a chance?
@afeldman-nm has a WIP branch for this
> @afeldman-nm has a WIP branch for this
Thanks — are you referring to the branch linked above that disables the multi-step scheduler?
- [Bugfix] Handle `best_of>1` & `use_beam_search` by disabling multi-step scheduling (#8637)
Yes - to avoid crashing the server.
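For readers landing here, that guard is roughly the following shape. This is a hypothetical sketch of "disable multi-step for these sampling params", not the actual code in PR #8637; the class and function names are made up for illustration.

```python
from dataclasses import dataclass

@dataclass
class SamplingParamsLike:
    """Stand-in for the request's sampling parameters (hypothetical)."""
    n: int = 1
    best_of: int = 1
    use_beam_search: bool = False

def can_use_multi_step(params: SamplingParamsLike) -> bool:
    """Return False for configurations that produce more than one sampled
    sequence per prompt per step, which is what trips the shape mismatch
    in the tracebacks above."""
    if params.use_beam_search:
        return False
    if params.n > 1 or params.best_of > 1:
        return False
    return True

# A server could check this up front and fall back to single-step
# scheduling (or reject the request) instead of crashing in the sampler.
assert not can_use_multi_step(SamplingParamsLike(n=2))
```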
We are not planning to support both multi-step and beam search at the same time. Instead, we are working on re-architecting vLLM to have asynchronous scheduling, which will accomplish the same throughput goal as multi-step while making it easier to support the other features.
However, if you have an idea for how to do this with multi-step, feel free to open up a PR.
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!