[Core] [Bugfix] Add Input Embeddings
> [!NOTE]
> This PR is just #11684, but rebased onto main and with pre-commit errors fixed, since it has been some time since @Bryce1010 last updated that PR.
Adds support for passing prompt_embeds to LLM.generate as

```python
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```

or

```python
llm.generate(
    [{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds],
    sampling_params,
)
```
This enables use cases where only the embedding layer is fine-tuned, letting the same model backend serve multiple custom-tuned embedding layers.
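For context, a minimal end-to-end sketch of the intended workflow; the model name is a placeholder and the embeddings here come from the stock HF embedding table rather than a fine-tuned one:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder; any supported decoder-only model

# Build (num_tokens, hidden_size) embeddings outside of vLLM, e.g. from a
# fine-tuned embedding layer. Here we just use the stock HF embedding table.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    input_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# Pass the precomputed embeddings to vLLM instead of a text prompt.
llm = LLM(model=model_name)
outputs = llm.generate({"prompt_embeds": input_embeds},
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```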
FIX https://github.com/vllm-project/vllm/issues/416
FIX https://github.com/vllm-project/vllm/issues/8323
FIX https://github.com/vllm-project/vllm/issues/14621
Our project needs this feature. Would love to see this one merged!
Can you also update this PR with the unit tests in https://github.com/vllm-project/vllm/pull/6869 to see whether this solution works correctly?
Oops I accidentally closed the PR, reopened it now
Hi,

I set `inputs_embeds` as `(num_tokens, embed_dim)` and got the following issues. Is there any advice? Thanks.
@DarkLight1337 I also hit this issue. When my input is (T, V), I get the error: `IndexError: index 14 is out of bounds for axis 0 with size 14`.
When my input is (B, T, V), I get the error: `RuntimeError: query, key, and positions must have the same number of tokens.`
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @qthequartermasterman.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
@yukang2017 @lzl-mt I had accidentally introduced a bug with batching. That part should be fixed now. Can you please verify that you no longer see index out of range errors? If you do, can you please provide a minimal reproducible example? I cannot reproduce them with these batching fixes.
Generation runs end to end now, but there is still an issue where the output text is gibberish when using prompt_embeds. To my semi-informed eye it feels like some sort of attention mask is not being respected. I plan on making some time tomorrow to diagnose that issue.
Thanks. I'll try it soon :D
@qthequartermasterman hi, I used the new version to build vLLM, and got this error:

```
model_outputs = self.model.llm.generate({"prompt_embeds": inputs_embeds}, sampling_params)
  File "/root/new/vllm/vllm/utils.py", line 1131, in inner
    return fn(*args, **kwargs)
  File "/root/new/vllm/vllm/entrypoints/llm.py", line 460, in generate
    self._validate_and_add_requests(
  File "/root/new/vllm/vllm/entrypoints/llm.py", line 1320, in _validate_and_add_requests
    self._add_request(
  File "/root/new/vllm/vllm/entrypoints/llm.py", line 1338, in _add_request
    self.llm_engine.add_request(
  File "/root/new/vllm/vllm/v1/engine/llm_engine.py", line 186, in add_request
    request = self.processor.process_inputs(request_id, prompt, params,
  File "/root/new/vllm/vllm/v1/engine/processor.py", line 201, in process_inputs
    processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
  File "/root/new/vllm/vllm/inputs/preprocess.py", line 754, in preprocess
    return self._process_decoder_only_prompt(
  File "/root/new/vllm/vllm/inputs/preprocess.py", line 703, in _process_decoder_only_prompt
    prompt_comps = self._prompt_to_llm_inputs(
  File "/root/new/vllm/vllm/inputs/preprocess.py", line 341, in _prompt_to_llm_inputs
    prompt_token_ids = tokens_content["prompt_token_ids"]
KeyError: 'prompt_token_ids'
```

Maybe we should initialize some fake ids? BTW, I use the Qwen2.5 1.5B Instruct model as my LLM backbone.
@qthequartermasterman Thanks for your help. The index range error is gone, but the output text is still gibberish.
I tried to initialize prompt_token_ids like this:
The error is gone, but the decoded tokens are incorrect.
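A purely hypothetical sketch of that kind of workaround, assuming the request dict accepts placeholder prompt_token_ids alongside prompt_embeds (the actual fields and behavior in this PR may differ):

```python
import torch

# Hypothetical workaround: pair the embeddings with dummy token ids so the
# preprocessor finds a "prompt_token_ids" key. Hidden size 1536 matches
# Qwen2.5-1.5B-Instruct; adjust for other models.
inputs_embeds = torch.zeros(12, 1536)
request = {
    "prompt_embeds": inputs_embeds,
    "prompt_token_ids": [0] * inputs_embeds.shape[0],  # one placeholder id per token
}
# outputs = llm.generate(request, sampling_params)
```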
@DarkLight1337 Two questions since you have a greater understanding of the vLLM internals than I do. I have been debugging most of the day, and don't have great answers to them. Thanks for your help so far on this PR!
(1) Do you have any idea what might be causing the inputs_embeds to be populated in the model_input on the first model forward pass, but to be empty on subsequent forward passes? The reason the output is garbage when using prompt embeds after the first token is that inputs_embeds doesn't seem to be set after the first pass, so it uses [0, 0, 0, 0, 344] (for example, where 0 is the placeholder token id for a token that was provided via embeds and 344 is the first generated token) as the input_ids to generate against.
(2) Do you have any thoughts on the best way of handling a mixture of inputs where some sequences are provided via token_ids and others via inputs_embeds? Right now many of the models (for example opt) completely ignore input ids if inputs_embeds are provided. I know #6869 modified all the model definitions to handle these mixed inputs, but I'm wondering if you can think of a simpler way of doing it without having to modify model definitions? My current thought is that during ModelInputForGPUBuilder.build, if there are any inputs_embeds for any of the inter_data, we could compute the appropriate inputs_embeds for the provided input_tokens. This feels like an odd place to do those computations, however, and I'm not sure how this would interact with lora or mrope.
> (1) Do you have any idea what might be causing the inputs_embeds to be populated in the model_input on the first model forward pass, but to be empty on subsequent forward passes?
Not sure. @WoosukKwon might have a better idea of this.
> (2) Do you have any thoughts on the best way of handling a mixture of inputs where some sequences are provided via token_ids and others via inputs_embeds?
You can group the batches based on whether there are prompt_embeds and pass each group one at a time to the model, then combine the outputs back together. But honestly I think this should be done at the scheduler level so each batch doesn't have mixed groups like this.
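Purely as an illustration of that grouping idea (the names and model interface below are hypothetical, not vLLM internals):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class SeqInput:
    token_ids: Optional[torch.Tensor] = None      # (num_tokens,)
    prompt_embeds: Optional[torch.Tensor] = None  # (num_tokens, hidden_size)


def run_mixed_batch(model, seqs: List[SeqInput]) -> List[torch.Tensor]:
    """Run ids-only and embeds-only sequences as separate sub-batches, then
    stitch the outputs back together in the original order. Assumes equal
    sequence lengths for simplicity."""
    ids_idx = [i for i, s in enumerate(seqs) if s.prompt_embeds is None]
    emb_idx = [i for i, s in enumerate(seqs) if s.prompt_embeds is not None]

    outputs: List[Optional[torch.Tensor]] = [None] * len(seqs)
    if ids_idx:
        hidden = model(input_ids=torch.stack([seqs[i].token_ids for i in ids_idx]))
        for i, h in zip(ids_idx, hidden):
            outputs[i] = h
    if emb_idx:
        hidden = model(inputs_embeds=torch.stack([seqs[i].prompt_embeds for i in emb_idx]))
        for i, h in zip(emb_idx, hidden):
            outputs[i] = h
    return outputs
```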
Hi, any update on this PR?
Can you merge from main? It looks like the basic correctness test is failing.
Also, are your previous tests able to catch the problem with batching? If not, can you add a test to that file to avoid regressions?
Do you have any idea how to solve the gibberish output? From my tests, it may be caused by the decoding phase of vLLM.
I also tried initializing fake prompt_token_ids and got the same error. Is there a better way to solve it?
> Do you have any idea how to solve the gibberish output? From my tests, it may be caused by the decoding phase of vLLM.
Figured out what happened. Within _compute_lens I had added a conditional that replaced seq_data.prompt_embeds[context_len:seq_len] with None whenever its length was 0. This was an incorrect "fix" for another bug, which suppressed the complaints about length mismatches within the forward pass. I had misunderstood what was wrong before. The real issue is that there are no new inputs_embeds corresponding to newly generated token ids; by replacing empty tensors with None, I painted over the deeper problem. After consulting with @DarkLight1337, I'm waiting on some input from @WoosukKwon on the best way to propagate embeddings for newly generated tokens.
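A minimal illustration of that pitfall, with made-up shapes (not the actual _compute_lens code):

```python
import torch

# Embeddings exist only for the 5 prompt tokens; during decode, the slice for
# the newly generated token is empty.
prompt_embeds = torch.zeros(5, 16)
context_len, seq_len = 5, 6  # decoding one new token

step_embeds = prompt_embeds[context_len:seq_len]  # shape (0, 16): no embedding for the new token
if len(step_embeds) == 0:
    # The incorrect "fix": silences the length-mismatch error instead of
    # supplying an embedding for the newly generated token.
    step_embeds = None
```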
This PR (https://github.com/vllm-project/vllm/pull/6869/files) runs successfully and maintains accuracy. However, when I try to merge it into the current main branch, it fails to run (getting stuck in an infinite loop during the engine step). Maybe we can take some inspiration from it to see if it helps resolve the current decoding issue? @qthequartermasterman @DarkLight1337
@lzl-mt I have been using it as inspiration after taking over #11684. I'm in contact with the original dev of #6869 (@Nan2018). That PR avoided having to propagate new inputs_embeds for newly generated tokens by modifying the model definition for every supported model. Right now most (all?) models completely ignore input_token_ids if inputs_embeds exist. That PR rewrote model definitions to create the embeddings for tokens that were missing inputs_embeds. It would be nice to avoid having to reproduce that, and instead have a more general solution.
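For readers unfamiliar with that pattern, a hedged sketch of what "create the embeddings for tokens that were missing inputs_embeds" can look like inside a model (a simplification, not actual vLLM model code):

```python
import torch
import torch.nn as nn


def merge_input_embeddings(embed_tokens: nn.Embedding,
                           input_ids: torch.Tensor,      # (num_tokens,)
                           inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size)
                           has_embed: torch.Tensor) -> torch.Tensor:  # (num_tokens,) bool
    # Embed every token id, then prefer the caller-supplied embeddings wherever
    # they exist and fall back to the embedding table elsewhere.
    table_embeds = embed_tokens(input_ids)
    return torch.where(has_embed.unsqueeze(-1), inputs_embeds, table_embeds)
```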
@qthequartermasterman Has the issue with output precision been resolved, or is there an estimated time for when it will be?
@lzl-mt It hasn't been fully resolved yet, and it will be at least a day before I can work on it again. I will be traveling today. I plan on spending time tomorrow to finish resolving it. I'm hoping nothing else surprising shows up while I'm in the code.
@lzl-mt @SnowCharmQ As an unrelated matter of pure curiosity, what use cases are y'all hoping to use this feature for?
For me, I want to modify the inputs_embeds to inject some task-relevant information. This can serve as implicit information rather than explicit prompts.
We are using the MooER model to deploy services related to large-scale audio understanding. We hope to use this feature to accelerate the inference engine.
> @Aaronhuang-778 BTW, do you meet this error? `vllm/vllm/inputs/preprocess.py", line 341, in _prompt_to_llm_inputs prompt_token_ids = tokens_content["prompt_token_ids"] KeyError: 'prompt_token_ids'`

Sorry. I am using my own preprocess method. Skipped this.
@qthequartermasterman Here is a simple test:

```python
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModel
import torch


def format_qwen_prompt(user_input: str) -> str:
    return (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{user_input}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def get_token_embeddings(prompt: str, model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()

    # Get the input_ids
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    print(tokenizer)
    input_ids = inputs["input_ids"]  # shape: [1, seq_len]

    # Look up the embeddings from the embedding table
    with torch.no_grad():
        embeddings = model.embed_tokens(input_ids)
    return embeddings


def main():
    model_path = "/jfs/zhenlin.liang/model/Qwen2.5-1.5B-Instruct/"
    prompt = format_qwen_prompt("法国的首都是哪里?")  # "What is the capital of France?"
    embeddings = get_token_embeddings(prompt, model_path)
    print("Embedding shape:", embeddings.shape)  # [1, seq_len, hidden_size]

    llm = LLM(
        model="/jfs/zhenlin.liang/model/Qwen2.5-1.5B-Instruct/",
        gpu_memory_utilization=0.7,
        max_model_len=128,
        enforce_eager=True,
        dtype="bfloat16",
    )
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.95,
        max_tokens=50,
    )

    outputs = llm.generate({"prompt_embeds": embeddings}, sampling_params)
    print(outputs)


if __name__ == '__main__':
    main()
```
The latest version (commit e7ab2a2) with VLLM_USE_V1=0 solves my gibberish output, thanks a lot.
Quick update since I haven't commented in a few days. VLLM_USE_V1=0 and enforce_eager=True work as expected for generation with inputs_embeds. I have updated all the unit tests (including adding some for the scheduler to address batching with both inputs_embeds and input_ids). The basic correctness tests are still failing because this change currently only works in eager mode. I plan on looking into that tomorrow.
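For anyone reproducing this state, a small sketch of the settings involved (model path is a placeholder): the V0 engine is selected via the VLLM_USE_V1 environment variable and CUDA graph capture is disabled with enforce_eager.

```python
import os

os.environ["VLLM_USE_V1"] = "0"  # select the V0 engine; set before importing vllm

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
          enforce_eager=True,                  # skip CUDA graph capture
          dtype="bfloat16")
# outputs = llm.generate({"prompt_embeds": inputs_embeds}, SamplingParams(max_tokens=50))
```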
@DarkLight1337 I believe this PR is ready for review. All the tests that are available are passing. I have confirmed that generation with prompt embeds is working in both eager and compiled modes.
@lzl-mt You shouldn't see any difference in behavior, but can you please verify the latest commit is working for you? All of my personal test scripts (similar to yours) are working perfectly.
Overall looks OK now, other than the pre-commit errors. Also, I would like to merge this after https://github.com/vllm-project/vllm/pull/15686, which will remove SingletonInputsAdapter.
