[Bug]: Mistral 3.1 Small Image inference is broken on 0.8.4
Your current environment
The output of `python collect_env.py`
🐛 Describe the bug
I tested this and found that vllm==0.8.3 works fine while vllm==0.8.4 fails.
Server:
vllm serve nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic --disable-log-requests
Client:
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
model_id = client.models.list().data[0].id
# Text inference
chat_response = client.chat.completions.create(
    model=model_id,
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Who are you?"},
        ],
    }],
)
print("Text Chat completion output:", chat_response.choices[0].message.content)
# Single-image input inference
image_url = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
]
prompt = "What's in this image?"
for img in image_url:
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": img}},
        ],
    }]
    chat_response = client.chat.completions.create(model=model_id, messages=messages)
    print("Single image Chat completion output:", chat_response.choices[0].message.content)
Here is the stack trace for the failure when the image request is sent to vllm==0.8.4:
INFO: 127.0.0.1:52340 - "POST /v1/chat/completions HTTP/1.1" 200 OK
ERROR 04-15 17:28:02 [core.py:387] EngineCore hit an exception: Traceback (most recent call last):
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 380, in run_engine_core
ERROR 04-15 17:28:02 [core.py:387] engine_core.run_busy_loop()
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 402, in run_busy_loop
ERROR 04-15 17:28:02 [core.py:387] self._process_engine_step()
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 431, in _process_engine_step
ERROR 04-15 17:28:02 [core.py:387] outputs = self.step_fn()
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 207, in step
ERROR 04-15 17:28:02 [core.py:387] output = self.model_executor.execute_model(scheduler_output)
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 77, in execute_model
ERROR 04-15 17:28:02 [core.py:387] output = self.collective_rpc("execute_model",
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 04-15 17:28:02 [core.py:387] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/utils.py", line 2378, in run_method
ERROR 04-15 17:28:02 [core.py:387] return func(*args, **kwargs)
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-15 17:28:02 [core.py:387] return func(*args, **kwargs)
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 242, in execute_model
ERROR 04-15 17:28:02 [core.py:387] output = self.model_runner.execute_model(scheduler_output)
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-15 17:28:02 [core.py:387] return func(*args, **kwargs)
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1002, in execute_model
ERROR 04-15 17:28:02 [core.py:387] self._execute_mm_encoder(scheduler_output)
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 888, in _execute_mm_encoder
ERROR 04-15 17:28:02 [core.py:387] self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
ERROR 04-15 17:28:02 [core.py:387] ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] File "/home/mgoin/venvs/vllm-rel/lib/python3.12/site-packages/vllm/v1/worker/utils.py", line 58, in scatter_mm_placeholders
ERROR 04-15 17:28:02 [core.py:387] placeholders[is_embed] = embeds
ERROR 04-15 17:28:02 [core.py:387] ~~~~~~~~~~~~^^^^^^^^^^
ERROR 04-15 17:28:02 [core.py:387] RuntimeError: shape mismatch: value tensor of shape [1980, 5120] cannot be broadcast to indexing result of shape [7920, 5120]
Before submitting a new issue...
- [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Hmm, I'm getting a different error when trying to run your MRE. The server hangs without emitting errors when processing the second image. I'm using TP=2 though.
cc @WoosukKwon @heheda12345, might this be a problem with the scheduler?
Actually, I just checked and it seems to be an issue with inter-process communication:
cc @njhill @p88h
I tried this out with TP=1. I don't get @mgoin's exception; instead, it hangs. It appears that the request is in the running state in the scheduler, but the engine is just spinning with no tokens produced on each step. This needs some deeper investigation, but I thought I would share it in case it gives any clues @DarkLight1337. I'm not sure yet, but it doesn't appear to be IPC-related.
P.S. Thanks @mgoin for the high-quality bug report with a simple reproducer and clear info!
The flame graph shows something somewhat similar to what #16432 is/was addressing, but this is on the receive path (GPU worker -> main).
@DarkLight1337 was that on main?
That might be a different issue altogether, but it might be interesting to set VLLM_MSGPACK_ZERO_COPY_THRESHOLD to some high value and see if it affects the receive path CPU pattern.
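In case it's useful, here is a minimal sketch of trying that from an offline script (the threshold value below is arbitrary; with `vllm serve` you would just export the variable in the shell before starting the server):

import os

# Set the threshold before vLLM is imported/initialized so that the engine and
# its worker processes inherit it. 256 MiB is just an arbitrary "high" value.
os.environ["VLLM_MSGPACK_ZERO_COPY_THRESHOLD"] = str(256 * 1024 * 1024)

from vllm import LLM  # imported after the env var is set

llm = LLM(model="nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic")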
@DarkLight1337 with @mgoin's image example, 182 tokens get scheduled in the first forward pass, but none are generated, and so none are scheduled in subsequent forward passes but the request remains in running state.
So it does indeed seem to be a problem with the scheduler. I wonder whether it's also related to #16834?
@ywang96 normally handles such matters but he is busy these few weeks. @heheda12345 @WoosukKwon do you have bandwidth to look into this?
I'm not that familiar with the scheduler so I can't help much...
I managed to reproduce scheduler hangs (with different multimodal inputs) on mistral-small-3.1, also on vLLM 0.8.3, which were also preceded by:
ValueError: Attempted to assign X + Y = Z multimodal tokens to W placeholders
which crashed the worker process, but not vLLM itself, which hung with several running requests making no progress.
What seems to stop this error from reproducing in our case is setting --disable-chunked-mm-input (and using vLLM 0.8.4, as this flag doesn't exist in 0.8.3).
So I suggest you also try this flag.
@hibukipanim can you share your engine parameters please? I've tried using the suggested flag but I'm still seeing the same error as the TS.
@pySilver please note that I didn't try to reproduce the OP's issue; I reported another instance of a hanging scheduler (requests showing as running indefinitely and new ones not being accepted) with Mistral Small 3.1. These are the args I used that helped avoid the error above:
--disable-chunked-mm-input --max-num-batched-tokens=3072 --tokenizer_mode mistral --config_format mistral --load_format mistral --max-model-len 128000 --limit-mm-per-prompt=image=10
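For reference, a rough offline-inference equivalent of those flags would look something like the sketch below. It assumes the CLI flags map 1:1 onto LLM() keyword arguments and uses the official checkpoint name as a placeholder; neither of these was verified in this thread:

from vllm import LLM

# Sketch only: mirrors the CLI flags above.
llm = LLM(
    model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # assumed checkpoint
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
    max_model_len=128000,
    max_num_batched_tokens=3072,
    limit_mm_per_prompt={"image": 10},
    disable_chunked_mm_input=True,  # the flag that avoided the hang for us
)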
I haven't gotten to the root cause yet, but I suspect the bug is in the input processor. In @mgoin's example, the single image maps to 7920 tokens, which does not seem reasonable.
I've been able to find the commit that introduced the issue in the Mistral3 processor, cc @DarkLight1337: https://github.com/vllm-project/vllm/commit/56d4aefa33f3f8ffaf74d02a8d7eef9523651864
# Same error as the original issue
uv pip uninstall vllm
export VLLM_COMMIT=56d4aefa33f3f8ffaf74d02a8d7eef9523651864
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}
python examples/offline_inference/vision_language.py -m mistral3 --num-prompts 1
# Works (the commit before)
uv pip uninstall vllm
export VLLM_COMMIT=dd143ef54137807fdb8f91b836c5ec6617dfb507
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}
python examples/offline_inference/vision_language.py -m mistral3 --num-prompts 1
Note that you will need to manually update get_patch_size in PixtralHFEncoderInfo due to a bug at the time:
def get_patch_size(self) -> int:
    spatial_merge_size = getattr(self.vision_config, "spatial_merge_size", 1)
    return self.vision_config.patch_size * spatial_merge_size
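As a sanity check on the numbers in the stack trace: 7920 is exactly 4 × 1980, i.e. spatial_merge_size² times the number of embeddings the encoder actually produced (spatial_merge_size = 2 here is my assumption about this model's vision config, not something confirmed above). That would be consistent with the placeholder count having been computed from the un-merged patch size:

# Rough sketch of the mismatch, assuming spatial_merge_size == 2.
spatial_merge_size = 2

placeholder_tokens = 7920  # what the processor reserved (from the stack trace)
encoder_tokens = 1980      # what the vision encoder actually produced

# A placeholder count derived from patch_size alone (ignoring the spatial
# merge) comes out spatial_merge_size**2 times too large.
assert placeholder_tokens == encoder_tokens * spatial_merge_size**2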
OK, the hanging issue is also solved by that PR. So it seems the issue indeed stems from mismatched placeholders. I wonder why the scheduler hangs instead of raising the placeholder assignment error, though... cc @ywang96
> but not vLLM itself, which hung with several running requests making no progress.
Sorry I'm late to this issue, and glad it's finally resolved. Regarding the server hang, it's possible that the request got scheduled in the eyes of the scheduler (because of the length mismatch) but crashed the multiprocess worker (because of TP) at execution time, so the main process never moved on to the next loop and thinks the current request is still "running". I wonder if this is specific to TP>1 only, though.
@ywang96 it's not related to TP. As mentioned above, the problem is that the model runner doesn't generate any tokens. The scheduler doesn't expect this, and the result is that we remain in the model loop indefinitely (since the request is still in the "running" state) with no subsequent tokens being scheduled in the batch.