
[Core] [Bugfix] Add Input Embeddings

Open qthequartermasterman opened this issue 8 months ago • 47 comments

[!NOTE] This PR is just #11684, rebased onto main and with pre-commit errors fixed, since it has been some time since @Bryce1010 last updated that PR.

Adds support for passing prompt_embeds to LLM.generate as

llm.generate({"prompt_embeds": input_embeds}, sampling_params)

or

llm.generate(
    [{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds], sampling_params
)

This enables use cases where only the embedding layer is fine-tuned, letting the same model backend serve multiple custom-tuned embedding layers.
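For illustration, a minimal end-to-end sketch of the intended workflow (the model name is only a placeholder, and squeezing to a 2-D (num_tokens, hidden_size) tensor is an assumption of this sketch, not something the PR mandates):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "Qwen/Qwen2.5-1.5B-Instruct"  # placeholder; any model whose embedding layer was fine-tuned
tokenizer = AutoTokenizer.from_pretrained(model_path)
hf_model = AutoModelForCausalLM.from_pretrained(model_path)

token_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids
with torch.no_grad():
    # (1, num_tokens, hidden_size) -> (num_tokens, hidden_size)
    input_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_path)
outputs = llm.generate({"prompt_embeds": input_embeds}, SamplingParams(max_tokens=32))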

FIX https://github.com/vllm-project/vllm/issues/416
FIX https://github.com/vllm-project/vllm/issues/8323
FIX https://github.com/vllm-project/vllm/issues/14621

qthequartermasterman avatar Mar 25 '25 02:03 qthequartermasterman

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Mar 25 '25 02:03 github-actions[bot]

Our project needs this feature. Would love to see this one merged!

liangwythu avatar Mar 25 '25 03:03 liangwythu

Can you also update this PR with the unit tests in https://github.com/vllm-project/vllm/pull/6869 to see whether this solution works correctly?

DarkLight1337 avatar Mar 25 '25 04:03 DarkLight1337

Oops I accidentally closed the PR, reopened it now

DarkLight1337 avatar Mar 26 '25 03:03 DarkLight1337

Hi,

I set inputs_embeds to shape (num_tokens, embed_dim) and get the following issue. Is there any advice? Thanks.

[screenshot of the error]

yukang2017 avatar Mar 27 '25 03:03 yukang2017

Hi,

I set inputs_embeds to shape (num_tokens, embed_dim) and get the following issue. Is there any advice? Thanks.

@DarkLight1337 I also hit this issue. When my input has shape (T, V), I get the error: IndexError: index 14 is out of bounds for axis 0 with size 14

When my input has shape (B, T, V), I get the error: RuntimeError: query, key, and positions must have the same number of tokens.

lzl-mt avatar Mar 27 '25 12:03 lzl-mt

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @qthequartermasterman.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 01 '25 08:04 mergify[bot]

@yukang2017 @lzl-mt I had accidentally introduced a bug with batching. That part should be fixed now. Can you please verify that you no longer see index-out-of-range errors? If you do, can you please share a minimal reproducible example? I cannot reproduce them with these batching fixes.

Generation runs end to end now, but there is still an issue where the output text is gibberish when using prompt_embeds. To my semi-informed eye it feels like some sort of attention mask is not being respected. I plan on making some time tomorrow to diagnose that issue.

qthequartermasterman avatar Apr 02 '25 21:04 qthequartermasterman

@yukang2017 @lzl-mt I had accidentally introduced a bug with batching. That part should be fixed now. Can you please verify that you no longer see index-out-of-range errors? […]

Thanks. I'll try it soon :D

lzl-mt avatar Apr 03 '25 03:04 lzl-mt

@yukang2017 @lzl-mt I had accidentally introduced a bug with batching. That part should be fixed now. Can you please verify that you no longer see index-out-of-range errors? […]

@qthequartermasterman Hi, I built vLLM from the new version and got this error:

model_outputs = self.model.llm.generate({"prompt_embeds": inputs_embeds}, sampling_params)
  File "/root/new/vllm/vllm/utils.py", line 1131, in inner
    return fn(*args, **kwargs)
  File "/root/new/vllm/vllm/entrypoints/llm.py", line 460, in generate
    self._validate_and_add_requests(
  File "/root/new/vllm/vllm/entrypoints/llm.py", line 1320, in _validate_and_add_requests
    self._add_request(
  File "/root/new/vllm/vllm/entrypoints/llm.py", line 1338, in _add_request
    self.llm_engine.add_request(
  File "/root/new/vllm/vllm/v1/engine/llm_engine.py", line 186, in add_request
    request = self.processor.process_inputs(request_id, prompt, params,
  File "/root/new/vllm/vllm/v1/engine/processor.py", line 201, in process_inputs
    processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
  File "/root/new/vllm/vllm/inputs/preprocess.py", line 754, in preprocess
    return self._process_decoder_only_prompt(
  File "/root/new/vllm/vllm/inputs/preprocess.py", line 703, in _process_decoder_only_prompt
    prompt_comps = self._prompt_to_llm_inputs(
  File "/root/new/vllm/vllm/inputs/preprocess.py", line 341, in _prompt_to_llm_inputs
    prompt_token_ids = tokens_content["prompt_token_ids"]
KeyError: 'prompt_token_ids'

Maybe we should initialize some fake token ids? BTW, I use the Qwen2.5 1.5B Instruct model as my LLM backbone.
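To make that concrete, the workaround being floated would look something like the sketch below. Whether this PR's input schema accepts both keys together is an assumption here, and, as noted in the next comments, this did not by itself produce correct decoding.

# Sketch only: pass dummy ids alongside the embeddings so preprocessing finds prompt_token_ids.
# inputs_embeds is assumed to be a (num_tokens, hidden_size) tensor.
placeholder_ids = [0] * inputs_embeds.shape[0]
outputs = llm.generate(
    {"prompt_embeds": inputs_embeds, "prompt_token_ids": placeholder_ids},
    sampling_params,
)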

lzl-mt avatar Apr 03 '25 11:04 lzl-mt

@qthequartermasterman Thanks for your help. The index range error is gone, but the output text is still gibberish.

yukang2017 avatar Apr 03 '25 12:04 yukang2017

@yukang2017 @lzl-mt I had accidentally introduced a bug with batching. That part should be fixed now. […]

@qthequartermasterman Hi, I built vLLM from the new version and got this error: […] KeyError: 'prompt_token_ids'

Maybe we should initialize some fake token ids? […]

I tried to initialize prompt_token_ids like this: [screenshot]

The error is gone, but the decoded tokens are incorrect. [screenshot]

lzl-mt avatar Apr 03 '25 12:04 lzl-mt

@DarkLight1337 Two questions since you have a greater understanding of the vLLM internals than I do. I have been debugging most of the day, and don't have great answers to them. Thanks for your help so far on this PR!

(1) Do you have any idea what might be causing the inputs_embeds to be populated in the model_input on the first model forward, but to be empty on subsequent forward passes? The reason the output is garbage when using prompt embeds after the first token is that inputs_embeds doesn't seem to be set after the first pass, so the model generates against input_ids like [0, 0, 0, 0, 344] (where 0 is the placeholder token id for a token that was provided via embeds and 344 was the first generated token).

(2) Do you have any thoughts on the best way of handling a mixture of inputs where some sequences are provided via token_ids and others via inputs_embeds? Right now many of the models (for example opt) completely ignore input ids if inputs_embeds are provided. I know #6869 modified all the model definitions to handle these mixed inputs, but I'm wondering if you can think of a simpler way of doing it without having to modify model definitions? My current thought is that during ModelInputForGPUBuilder.build, if there are any inputs_embeds for any of the inter_data, we could compute the appropriate inputs_embeds for the provided input_tokens. This feels like an odd place to do those computations, however, and I'm not sure how this would interact with lora or mrope.
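To picture what that could look like, here is a minimal inference-time sketch of "fill in embeddings for the token-id positions" (all names here are hypothetical, this is not the PR's actual code, and it ignores the LoRA/mrope interactions mentioned above):

import torch

def build_inputs_embeds(
    input_ids: torch.Tensor,        # (num_tokens,) token ids; placeholders where embeds were provided
    provided_embeds: torch.Tensor,  # (num_provided, hidden_size) precomputed prompt embeddings
    has_embeds: torch.Tensor,       # (num_tokens,) bool mask, True where an embedding was provided
    embed_layer: torch.nn.Embedding,
) -> torch.Tensor:
    # Look up an embedding for every position from its token id...
    inputs_embeds = embed_layer(input_ids)
    # ...then overwrite the positions that already came with precomputed embeddings.
    inputs_embeds[has_embeds] = provided_embeds.to(inputs_embeds.dtype)
    return inputs_embeds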

qthequartermasterman avatar Apr 04 '25 01:04 qthequartermasterman

(1) Do you have any idea what might be causing the inputs_embeds to be populated in the model_input on the first model forward, but to be empty on subsequent forward passes? […]

Not sure. @WoosukKwon might have a better idea of this.

(2) Do you have any thoughts on the best way of handling a mixture of inputs where some sequences are provided via token_ids and others via inputs_embeds? […]

You can group the batches based on whether there are prompt_embeds and pass each group one at a time to the model, then combine the outputs back together. But honestly I think this should be done at the scheduler level so each batch doesn't have mixed groups like this.
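A rough sketch of the group-and-recombine option (illustrative only; the sequence-group objects and the model call are stand-ins, not vLLM's real interfaces):

def run_in_groups(seq_groups, run_model):
    # Split the batch by input kind, remembering each group's original position.
    with_embeds = [(i, g) for i, g in enumerate(seq_groups) if getattr(g, "prompt_embeds", None) is not None]
    with_ids = [(i, g) for i, g in enumerate(seq_groups) if getattr(g, "prompt_embeds", None) is None]

    outputs = [None] * len(seq_groups)
    for group in (with_embeds, with_ids):
        if not group:
            continue
        indices, groups = zip(*group)
        # Run the model once per homogeneous group, then scatter results back in order.
        for idx, out in zip(indices, run_model(list(groups))):
            outputs[idx] = out
    return outputs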

DarkLight1337 avatar Apr 04 '25 02:04 DarkLight1337

Hi, any update on this PR?

SnowCharmQ avatar Apr 06 '25 03:04 SnowCharmQ

Can you merge from main? Looks like the basic correctness test is failing.

DarkLight1337 avatar Apr 06 '25 04:04 DarkLight1337

Also, are your previous tests able to catch the problem with batching? If not, can you add a test to that file to avoid regressions?

DarkLight1337 avatar Apr 06 '25 04:04 DarkLight1337

Do you have any idea how to solve the gibberish output? From my tests, it may be caused by the decoding phase of vLLM.

Aaronhuang-778 avatar Apr 08 '25 07:04 Aaronhuang-778

@qthequartermasterman Hi, I built vLLM from the new version and got this error: […] KeyError: 'prompt_token_ids'

Maybe we should initialize some fake token ids? […]

I also tried this and got the same error. Is there a better way to solve it?

SnowCharmQ avatar Apr 08 '25 07:04 SnowCharmQ

Do you have any idea how to solve the gibberish output? From my tests, it may be caused by the decoding phase of vLLM.

Figured out what happened. Within _compute_lens I had added a conditional that replaced seq_data.prompt_embeds[context_len:seq_len] with None whenever its length was 0. This was an incorrect "fix" for another bug that suppressed the complaints about length mismatches within the forward pass; I had misunderstood what was wrong before. The real issue is that the newly generated token ids have no corresponding inputs_embeds; by replacing empty tensors with None, I painted over the deeper problem. After consulting with @DarkLight1337, I'm waiting on some input from @WoosukKwon on the best way to propagate embeddings for newly generated tokens.

This PR (https://github.com/vllm-project/vllm/pull/6869/files) runs successfully and maintains accuracy. However, when I try to merge it into the current main branch, it fails to run (getting stuck in an infinite loop during the engine step). Maybe we can take some inspiration from it to see if it helps resolve the current decoding issue? @qthequartermasterman @DarkLight1337

@lzl-mt I have been using it as inspiration after taking over #11684. I'm in contact with the original dev of #6869 (@Nan2018). That PR avoided having to propagate new inputs embeds for newly generated tokens by modifying the model definition for every supported model. Right now most (all?) models completely ignore input_token_ids if inputs_embeds exist. That PR rewrote model definitions to create the embeddings for tokens that were missing inputs_embeds. It would be nice to avoid having to reproduce that, and instead have a more general solution.
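For intuition, the more general alternative (propagating an embedding for each newly sampled token so model definitions don't need per-model changes) might look roughly like this sketch; the names here are hypothetical and not vLLM's actual structures:

import torch

def append_generated_token(seq, new_token_id: int, embed_layer: torch.nn.Embedding):
    # The prompt positions already carry embeddings; extend them with the embedding of the
    # token just sampled, so later decode steps see a full (num_tokens, hidden_size) tensor
    # instead of falling back to placeholder token ids.
    with torch.no_grad():
        new_embed = embed_layer(torch.tensor([new_token_id]))  # (1, hidden_size)
    seq.prompt_embeds = torch.cat([seq.prompt_embeds, new_embed.to(seq.prompt_embeds.dtype)], dim=0)
    seq.token_ids.append(new_token_id)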

qthequartermasterman avatar Apr 08 '25 14:04 qthequartermasterman

@lzl-mt I have been using it as inspiration after taking over #11684. […] It would be nice to avoid having to reproduce that, and instead have a more general solution.

@qthequartermasterman Has the issue with output precision been resolved, or is there an estimated time for when it will be?

lzl-mt avatar Apr 09 '25 06:04 lzl-mt

Has the issue with output precision been resolved, or is there an estimated time for when it will be?

@lzl-mt It hasn't been fully resolved yet, and it will be at least a day before I can work on it again. I will be traveling today and plan on spending time tomorrow to finish resolving it. I'm hoping nothing else surprising shows up while I'm in the code.

@lzl-mt @SnowCharmQ As an unrelated matter of pure curiosity, what use cases are y'all hoping to use this feature for?

qthequartermasterman avatar Apr 09 '25 10:04 qthequartermasterman

@lzl-mt @SnowCharmQ As an unrelated matter of pure curiosity, what use cases are y'all hoping to use this feature for?

For me, I want to modify the inputs_embeds to plug in some task-relevant information. This serves as implicit information rather than explicit prompts.

SnowCharmQ avatar Apr 09 '25 12:04 SnowCharmQ

@lzl-mt @SnowCharmQ As an unrelated matter of pure curiosity, what use cases are y'all hoping to use this feature for?

We are using the MooER model to deploy services related to large-scale audio understanding, and we hope to use this feature to accelerate the inference engine.

lzl-mt avatar Apr 10 '25 06:04 lzl-mt

Sorry, I am using my own preprocessing method, so I skipped this.

(Replying by email to @lzl-mt's review comment on vllm/worker/model_runner.py, which asked: "@Aaronhuang-778 BTW, do you meet this error? vllm/vllm/inputs/preprocess.py, line 341, in _prompt_to_llm_inputs: prompt_token_ids = tokens_content["prompt_token_ids"] — KeyError: 'prompt_token_ids'")

Aaronhuang-778 avatar Apr 10 '25 09:04 Aaronhuang-778

@qthequartermasterman Here is a simple test:

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModel
import torch

def format_qwen_prompt(user_input: str) -> str:
    return (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{user_input}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

def get_token_embeddings(prompt: str, model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    model.eval()

    # Get the input_ids
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    print(tokenizer)
    input_ids = inputs["input_ids"]  # shape: [1, seq_len]

    # Look up the embeddings from the embedding table
    with torch.no_grad():
        embeddings = model.embed_tokens(input_ids)

    return embeddings

def main():
    model_path = "/jfs/zhenlin.liang/model/Qwen2.5-1.5B-Instruct/"
    prompt = format_qwen_prompt("法国的首都是哪里?")  # "What is the capital of France?"

    embeddings = get_token_embeddings(prompt, model_path)

    print("Embedding shape:", embeddings.shape)  # [seq_len, hidden_size]
    llm = LLM(
        model="/jfs/zhenlin.liang/model/Qwen2.5-1.5B-Instruct/",
        gpu_memory_utilization=0.7,
        max_model_len=128,
        enforce_eager=True,
        dtype="bfloat16"
    )

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.95,
        max_tokens=50
    )
    outputs = llm.generate({"prompt_embeds": embeddings}, sampling_params)
    print(outputs)

if __name__ == '__main__':
    main()

lzl-mt avatar Apr 10 '25 13:04 lzl-mt

The latest version (commit e7ab2a2) with VLLM_USE_V1=0 solves my gibberish output, thanks a lot!

lzl-mt avatar Apr 11 '25 08:04 lzl-mt

Quick update since I haven't commented in a few days. VLLM_USE_V1=0 and enforce_eager=True work as expected for generation with inputs_embeds. I have updated all the unit tests (including adding some for the scheduler to cover batching with both inputs_embeds and input_ids). The basic correctness tests are still failing because this change currently only works in eager mode. I plan on looking into that tomorrow.
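For anyone following along, the working configuration at this point looks like the sketch below (the model path is a placeholder, and inputs_embeds is assumed to be a (num_tokens, hidden_size) tensor computed as in the test script above):

import os
os.environ["VLLM_USE_V1"] = "0"  # select the V0 engine before importing vllm

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enforce_eager=True, dtype="bfloat16")
outputs = llm.generate({"prompt_embeds": inputs_embeds}, SamplingParams(temperature=0.7, max_tokens=50))
print(outputs[0].outputs[0].text)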

qthequartermasterman avatar Apr 16 '25 03:04 qthequartermasterman

@DarkLight1337 I believe this PR is ready for review. All the tests that are available are passing. I have confirmed that generation with prompt embeds is working in both eager and compiled modes.

@lzl-mt You shouldn't see any difference in behavior, but can you please verify the latest commit is working for you? All of my personal test scripts (similar to yours) are working perfectly.

qthequartermasterman avatar Apr 16 '25 19:04 qthequartermasterman

Overall this looks OK now, other than the pre-commit errors. Also, I would like to merge this after https://github.com/vllm-project/vllm/pull/15686, which will remove SingletonInputsAdapter.

DarkLight1337 avatar Apr 21 '25 14:04 DarkLight1337