vllm
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API
Key Features of This Pull Request
This PR adds the e5-mistral-7b-instruct model and enables end-to-end (E2E) embedding vector generation.
There are a few open PRs that add support for embedding models. Our PR uniquely addresses the following key issues:
Integration of the OpenAI Server Front End
This PR includes comprehensive end-to-end (E2E) embedding functionality, spanning from the OpenAI Front End to the Back End.
Turn off KV cache with embedding models
This PR introduces the capability to turn off KV cache when operating in embedding mode, which includes the block_tables and cache_engine.
High Level Design
The embedding model can essentially be considered a special type of generative model with max_token=1 from an inference perspective. Both embedding and generative models (with max_token=1) require a single feedforward calculation without the need for generating any subsequent new tokens. The primary differences are:
The embedding model returns the hidden state, while the generative model takes an extra step to sample and return the first output token. In this sense, the embedding model is a subset of the generative model in terms of the calculations performed during the single feedforward operation on GPUs. They differ in their API specifications concerning request and output formats.
As a consequence, the serving of embedding models:
- Is independent of any technologies implemented for subsequent token-by-token generation; and
- Shares the same technologies as generative models for prompt processing.
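To make the comparison above concrete, here is a minimal toy sketch of the two paths. Nothing in it is vLLM code: `toy_forward`, `embed`, and `generate_first_token` are hypothetical names, the "model" is a stand-in that maps each token id to a two-element vector, and last-token pooling is assumed for the embedding path.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def toy_forward(token_ids: Sequence[int]) -> List[Vector]:
    """Toy stand-in for the single prompt forward pass: one hidden vector per token."""
    return [[float(t), 0.5 * float(t)] for t in token_ids]


def embed(token_ids: Sequence[int]) -> Vector:
    """Embedding path: stop after the forward pass and return a hidden state."""
    hidden_states = toy_forward(token_ids)
    return hidden_states[-1]  # assumption: last-token pooling


def generate_first_token(
    token_ids: Sequence[int], lm_head: Callable[[Vector], List[float]]
) -> int:
    """Generative path with max_token=1: same forward pass, plus one sampling step."""
    hidden_states = toy_forward(token_ids)
    logits = lm_head(hidden_states[-1])
    # Greedy sampling: pick the index of the largest logit.
    return max(range(len(logits)), key=lambda i: logits[i])
```

Both functions call the same forward pass once; only the final step differs, which is why nothing that exists for later decoding steps is needed on the embedding path.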
Therefore, the current implementation in this PR focuses on bypassing the vLLM components that only serve subsequent token generation, such as the KV cache and CUDA Graph. This avoids unnecessary GPU memory usage, which in turn helps embedding-serving performance. If these technologies (e.g., CUDA Graph) are enhanced to improve prompt processing, they can be directly applied to embedding models. At a high level, our PR is independent of crucial features such as tensor parallelism and quantization, which are addressed in Milestone 1 below. We have tested it with tensor parallelism > 1 and it works. We are willing to work together to test and improve these features where needed.
Benchmarking Results
On 1 H100 GPU, vLLM’s serving of the E5 embedding model reaches a max throughput up to 36.6k tokens/s and remains consistent across sequence lengths from 128 to 32k (when gpu_memory_utilization is 0.8).
On 1 A100 GPU, it reaches up to 15.3k tokens/s and also remains consistent across sequence lengths. As a comparison, in a test of the latest ONNX, the speed is up to 12.6k tokens/s when the sequence length is short (256 tokens), drops to 8k tokens/s when the sequence length is 2k, and shows a general trend of getting worse as the sequence length grows.
See the figure for more details of our testing results on H100 and A100
Note: The throughput is measured as "number of sequences in a batch" * "sequence length" / "end-to-end latency"
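As a small worked example of the formula in the note, the helper below computes throughput from the three quantities. The function name and the sample numbers are illustrative only and are not taken from the benchmark runs.

```python
def throughput_tokens_per_s(
    num_sequences: int, sequence_length: int, latency_s: float
) -> float:
    """Tokens per second for one batch, as defined in the note above."""
    if latency_s <= 0:
        raise ValueError("latency_s must be positive")
    return num_sequences * sequence_length / latency_s


# Illustrative numbers: 8 sequences of 1024 tokens finishing in 0.5 s.
example = throughput_tokens_per_s(8, 1024, 0.5)
```

Because the numerator counts every token in the batch, a flat throughput curve across sequence lengths means end-to-end latency grows roughly in proportion to the total token count.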
Design and Implementation Details
Add Embedding generation
Add Embedding API to entrypoints/openai
- Add EmbeddingRequest and EmbeddingResponse to protocol.py
- Add serving_embedding.py
- Add OpenAIServingEmbedding to api_server
- Make llm.py work with embedding
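For orientation, here is a hedged sketch of what a client-side round trip could look like. It assumes the request and response bodies mirror the public OpenAI Embeddings API shape (a "model" and "input" in the request, a "data" list of items with "index" and "embedding" in the response); the exact fields of this PR's EmbeddingRequest and EmbeddingResponse are not shown here, and both helper names are hypothetical.

```python
import json
from typing import List


def build_embedding_request(model: str, inputs: List[str]) -> str:
    """Serialize a request body in the OpenAI Embeddings API shape (assumed)."""
    return json.dumps({"model": model, "input": inputs})


def parse_embedding_response(body: str) -> List[List[float]]:
    """Return the embeddings ordered by their 'index' field."""
    payload = json.loads(body)
    items = sorted(payload["data"], key=lambda item: item["index"])
    return [item["embedding"] for item in items]


# A hand-written response used only to exercise the parser.
sample_response = json.dumps(
    {
        "object": "list",
        "data": [
            {"object": "embedding", "index": 1, "embedding": [0.3, 0.4]},
            {"object": "embedding", "index": 0, "embedding": [0.1, 0.2]},
        ],
        "model": "intfloat/e5-mistral-7b-instruct",
    }
)
vectors = parse_embedding_response(sample_response)
```

The serialized request would be POSTed to the server's embeddings route; the network call is left out so the sketch stays self-contained.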
Add Embedding in outputs and sequence
- Make an abstract base class of RequestOutput and SequenceGroupOutput.
- Add separate Completion*Output and Embedding*Output.
- Add EmbeddingOutput, EmbeddingRequestOutput, RequestOutputFactory and EmbeddingSequenceGroupOutput to support processing the embedding output sequence.
- Update process output and sequence in *llm_engine to use embedding
Add MistralModel and embedding generation in llama.py
- Add llama.embedding() and load_weights in LlamaModel to support forward and embedding vector generation
- Adapted from code examples in https://huggingface.co/intfloat/e5-mistral-7b-instruct
- Mistral uses LlamaModel.
- Use embedding when embedding_mode is True in model_runner
- Add load_weights in llama.py to support embedding models
Disable KV Cache for Embedding serving
Skip slot_mapping with embedding mode
- slot_mapping is only used in model_executor/layers/attention.py when kv_cache is not None. In embedding mode, we pass None as kv_cache, so there is no need to process slot_mapping.
Turn off KV cache for embedding mode
The goal is to disable the block_table and cache_engine completely, so we don't consider allocating blocks for KV cache for embedding mode
- Add embedding_mode to ModelConfig and SchedulerConfig
- Add a BlockSpaceManagerProxy to control the block management behavior for embedding mode
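One way to picture "no block allocation in embedding mode" is a block manager whose methods are all no-ops. The sketch below is a simplified illustration under that assumption: `AllocStatus`, `NoCacheBlockManager`, and every method name are hypothetical and do not reproduce the BlockSpaceManagerProxy added in this PR.

```python
from enum import Enum, auto
from typing import List


class AllocStatus(Enum):
    OK = auto()
    LATER = auto()
    NEVER = auto()


class NoCacheBlockManager:
    """Hypothetical block manager for embedding mode: it never reserves KV blocks."""

    def can_allocate(self, request_id: str) -> AllocStatus:
        # There is no KV cache to run out of, so scheduling is never blocked here.
        return AllocStatus.OK

    def allocate(self, request_id: str) -> None:
        # Nothing to reserve: the prompt is processed once and never decoded further.
        return None

    def get_block_table(self, request_id: str) -> List[int]:
        # An empty table signals that no cache blocks back this request.
        return []

    def free(self, request_id: str) -> None:
        # Nothing was allocated, so there is nothing to release.
        return None


manager = NoCacheBlockManager()
```

With such a manager, the scheduler is limited only by the batched-token budget rather than by free cache blocks.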
Update parameters for max batching
- Add profile_max_batched_tokens_for_embedding to profile the max_num_batched_tokens for embedding server mode
- Return max_batch_size as each Ray worker runs the profiling once
- Enable embedding profiling in ray_gpu_executor.py to support tensor parallelism > 1
Notes
Overlaps with other PRs
This PR overlaps with the following Milestones in Supporting embedding models #3187
Milestone 1: Basic functionality using Hugging Face models [Partially completed]
Note: Instead of using LLM.encode(), this PR currently adds embedding() to LlamaModel, and then keeps LLM.generate() for embedding.
G) Update parameters for max batching [Completed]
We introduced profile_max_batched_tokens_for_embedding in gpu_executor.py to support the maximum number of tokens the GPU can take in one batch.
Milestone 2: Implement the model properly [Completed]
This PR focuses on adding the e5-mistral-7b-instruct model, which can utilize llama.py. So it already uses the vLLM layer primitives.
Milestone 4: Add the OpenAI Server Front End [Completed]
This PR has implemented the Embedding API to entrypoints/openai.
Discrepancies with other discussions
This PR didn't implement the following in this discussion in Supporting embedding models #3187
Move finish sequence logic to check_stop
- Currently, the logic is in llm_engine._process_model_output()
- The logic should be in llm_engine._check_stop
Automatically detect that the model is an embedding model
- The user should not have to specify that it is an embedding model
- Somewhere in the vllm code, create a registry that selects which models are embedding and which models are decoders
F) Update Pass EmbeddingParams around instead of SamplingParams
- note: this is going to require a lot of work
Note: This PR passes SamplingParams() to the LLMEngine and disables its use in embedding mode. As separating EmbeddingParams and SamplingParams requires changes to the UX, it would be easier to discuss and review in a following PR.
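The "registry" idea in the list above could look roughly like the sketch below. It assumes detection is keyed on the architecture names a model declares in its config.json; the two sets and the function name are hypothetical and are not the registry vLLM ships.

```python
from typing import Iterable

# Hypothetical registry contents, keyed by architecture name from config.json.
_EMBEDDING_ARCHITECTURES = {"MistralModel"}
_GENERATION_ARCHITECTURES = {"MistralForCausalLM", "LlamaForCausalLM"}


def is_embedding_model(architectures: Iterable[str]) -> bool:
    """Decide the serving mode from declared architectures instead of a user flag."""
    archs = list(architectures)
    if any(arch in _EMBEDDING_ARCHITECTURES for arch in archs):
        return True
    if any(arch in _GENERATION_ARCHITECTURES for arch in archs):
        return False
    raise ValueError(f"Unsupported architectures: {archs}")
```

The user would then no longer pass an explicit embedding flag; unknown architectures fail loudly rather than silently defaulting to one mode.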
Discussion Points and Considerations
Simplifying the Workflow
- Evaluate the possibility of consolidating embedding-related workflows into existing structures with minimal branching logic. This could involve using if-else statements (specifically in model_runner.py) or integrating embedding as a subset of generation needs.
Soundness of Profiling
- Evaluate if the current profile_max_batched_tokens_for_embedding is sufficient to support the maximum number of tokens in one batch without causing CUDA OOM.
An overview of tasks derived from the pull request discussions:
Immediate changes
core/
- [x] Make a BlockSpaceManagerV3
- [x] Remove self.prompt_limit if branch in scheduler.py
worker/ and executor/
- [x] Remove profile_max_num_batched_tokens for embedding
- [x] Evaluate to set a hardcoded batch_size or max_num_batched_tokens (32k as Woosuk and Zhuohan suggested) instead of profiling
models/
- [x] Make a new embedding_models_dict
- [x] Make a llama_embedding.py
- [x] Check how HuggingFace(HF) generates embedding
Separate model_runner and embedding_model_runner
- [x] Implement execute_model separately in model_runner and embedding_model_runner
Further evaluation
- Evaluating the possibility of using config.json to distinguish whether a model is for embedding or generation, e.g. XForCausalLM as a generation model. Check edge cases for fine-tuning.
- Separate LLM and LLMEmbedding
- Separate SamplingParams and EmbeddingParams
- Evaluate latency on individual request level
HF embeddings
HF's sentence_transformers provides sentence, text, and image embeddings. It uses encode() for computing sentence embeddings, and its Pooling class supports different types of pooling, including lasttoken, which e5-mistral-7b-instruct uses.
Adhering to the same design would require further discussion and evaluation of the points mentioned above.
@ywang96 thanks for the initial review!
I have applied the suggestions and resolved some comments.
The current implementation doesn't allow an engine to do generation and embedding at the same time (since the embedding_mode is passed to initialize the engine), so I wonder if it's actually worth the effort to create a separate EmbeddingEngine since a lot of logics in LLMEngine are completely not needed, and this separation could also make a lot of higher level APIs cleaner.
I agree that having a separate EmbeddingEngine has a lot of benefits. Besides making the design cleaner, adding more functionalities in embedding and supporting more embedding models would be easier. Happy to discuss detailed design. It might be good for a following PR as the current is big.
@ywang96 A question related to design:
intfloat/e5-mistral-7b-instruct requires an eos_token_id appended at the end of each input. See:
# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=max_length - 1, return_attention_mask=False, padding=False, truncation=True)
# append eos_token_id to every input_ids
batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
batch_dict = tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt')
I checked a few embedding models, and this seems specific to this model. Since input_ids are processed in ModelRunner._prepare_prompt, is there a good place to inject eos_token_id at the end of each input?
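For reference, the append step itself is small once the token ids are available as plain lists. The sketch below only mirrors the truncate-then-append behaviour of the snippet above; `append_eos` is a hypothetical helper and does not answer where in ModelRunner the call should live.

```python
from typing import List, Sequence


def append_eos(
    batch_input_ids: Sequence[Sequence[int]], eos_token_id: int, max_length: int
) -> List[List[int]]:
    """Truncate each prompt to max_length - 1 tokens, then append the EOS id."""
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    return [
        list(input_ids[: max_length - 1]) + [eos_token_id]
        for input_ids in batch_input_ids
    ]
```

Truncating to max_length - 1 first guarantees the EOS token always fits inside the length budget, which matters here because last-token pooling reads the hidden state at that final position.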
cc @simon-mo for plan here
@CatherineSue This is in good shape. We are very close to being ready to merge. @ywang96 and I discussed. There are a few things we want to do for the final implementation, but we want to take an approach of landing something close to the current version and then following up with incremental work.
Here's what we want to do to get ready for merge:
Needed For Merge
- [ ] Tests
- [ ] Clean up llama_embedding.py
- [ ] Update terminology from LlamaEmbeddingModel.embedding to LlamaEmbeddingModel.pooler
- [ ] Log warning from LLM.generate() that this API will change for embedding models
- [ ] Add checks for incompatible features and fail if so [ we can push this to another PR if difficult ]
Tests
This is a big feature that needs end-to-end testing. I would suggest that we focus in the following areas:
- Model correctness: compare vs sentence-transformers or hugging face implementation of the model. Use L2 norm of the differences between the embeddings. This can use the LLM engine.
- Server API correctness: show that querying the API gets what is expected from the client side. Check out the existing tests for the OpenAI API server for inspiration as to how to write these tests
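A minimal sketch of the model-correctness check described above, using plain Python lists in place of tensors. The helper names and the tolerance value are assumptions for illustration; a real test would pick the tolerance based on dtype and model.

```python
import math
from typing import Sequence


def l2_distance(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """L2 norm of the element-wise difference between two embeddings."""
    if len(reference) != len(candidate):
        raise ValueError("embeddings must have the same dimension")
    return math.sqrt(sum((r - c) ** 2 for r, c in zip(reference, candidate)))


def embeddings_match(
    reference: Sequence[float], candidate: Sequence[float], tol: float = 1e-2
) -> bool:
    """True when the candidate is within tol of the reference (tol is an assumed value)."""
    return l2_distance(reference, candidate) <= tol
```

In a test, `reference` would come from the sentence-transformers or Hugging Face implementation and `candidate` from the engine under test, compared prompt by prompt.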
Clean up llama_embedding.py
We currently re-implement all the llama layers. I am okay with having a separate file for llama_embedding.py, but we should import LlamaModel and use that rather than re-writing all the layers of the model. For example:
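from torch import nn
from vllm.model_executor.models.llama import LlamaModel

class LlamaEmbeddingModel(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        # Reuse the existing LlamaModel layers instead of re-writing them
        self.model = LlamaModel(**kwargs)

    def forward(self, *inputs):
        return self.model(*inputs)

    # currently called embedding
    def pooler(self, hidden_states):
        # same as current embedding function
        ...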
Update Terminology from embedding to pooler
This is to be consistent with HF / sentence-transformers, which use this terminology for translating between the final hidden states and output embedding. See official BERTModel implementation.
So specifically
- LlamaEmbeddingModel should use pooler instead of embedding
- EmbeddingModelRunner should call the pooler method
Update LLM.generate API to log a warning if used with the embedding model
- Make a note that this interface is experimental, and that LLM.encode will replace it soon
Add Incompatible Feature Guards
The following features are incompatible with embedding models:
- Spec Decode
- Chunked Prefill
- Automatic Prefix Caching
- Neuron / CPU
- Fp8 KV cache
If these features are specified for an embedding model, we should either fail or log a warning. I would be okay with doing this in a follow up PR if it is not straightforward.
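A guard along these lines could be as simple as the sketch below. `EngineOptions` and its field names are hypothetical stand-ins for whatever configuration objects carry these settings; only the list of incompatible features comes from the comment above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EngineOptions:
    """Hypothetical flattened view of the relevant engine settings."""

    embedding_mode: bool = False
    speculative_decoding: bool = False
    chunked_prefill: bool = False
    prefix_caching: bool = False
    device: str = "cuda"
    kv_cache_dtype: str = "auto"


def find_incompatible_features(opts: EngineOptions) -> List[str]:
    """List the enabled features that embedding models cannot be combined with."""
    if not opts.embedding_mode:
        return []
    problems = []
    if opts.speculative_decoding:
        problems.append("speculative decoding")
    if opts.chunked_prefill:
        problems.append("chunked prefill")
    if opts.prefix_caching:
        problems.append("automatic prefix caching")
    if opts.device in ("neuron", "cpu"):
        problems.append(f"{opts.device} device")
    if opts.kv_cache_dtype.startswith("fp8"):
        problems.append("fp8 KV cache")
    return problems


def validate_engine_options(opts: EngineOptions) -> None:
    """Fail fast instead of serving an embedding model with unsupported features."""
    problems = find_incompatible_features(opts)
    if problems:
        raise ValueError("Not supported with embedding models: " + ", ".join(problems))
```

Swapping the `raise` for a logged warning gives the softer variant mentioned above.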
I will make a follow up note for
Follow Ups Post Merge
After we merge this initial implementation, we can refactor in the following way:
Replace SamplerXXX with PoolerXXX
We currently implement EmbeddingModels using the input (SamplingParams) and output (SamplerOutput) classes in a hacked up manner
We will refactor to:
- Swap SamplingParameters for PoolerParameters, which only has the data needed for embedding models
- Swap SamplerOutput for PoolerOutput, which only has the data needed for embedding models
- Pipe all this info around the various layers of the engine
LLM.encode
Deprecate embedding models from LLM.generate(). Instead expose LLM.encode(), which accepts PoolerParams and returns EmbeddingRequestOutput
Generic Pooler
Create a generic Pooler() that corresponds to Sampler(). Currently, we implement pooling logic in LlamaEmbeddingModel.embedding() rather than in a class shared across models. Pooler could be instantiated with the sentence transformer config.
This will allow us to support more complex methods like ColBERT, sparse, etc. over time.
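As a rough illustration of what a shared pooling class might look like, here is a pure-Python sketch that operates on one unpadded sequence of hidden states. The `PoolingType` members, the `normalize` flag, and the class layout are assumptions made for this example and are not the Pooler that eventually landed.

```python
import math
from enum import Enum
from typing import List

Vector = List[float]


class PoolingType(Enum):
    LAST = "lasttoken"
    MEAN = "mean"
    CLS = "cls"


class Pooler:
    """Hypothetical model-agnostic pooler: hidden states in, one embedding out."""

    def __init__(self, pooling_type: PoolingType, normalize: bool = True) -> None:
        self.pooling_type = pooling_type
        self.normalize = normalize

    def __call__(self, hidden_states: List[Vector]) -> Vector:
        if not hidden_states:
            raise ValueError("hidden_states must contain at least one token")
        if self.pooling_type is PoolingType.LAST:
            pooled = list(hidden_states[-1])
        elif self.pooling_type is PoolingType.CLS:
            pooled = list(hidden_states[0])
        else:
            # Mean pooling: average each hidden dimension over all tokens.
            count = len(hidden_states)
            pooled = [sum(column) / count for column in zip(*hidden_states)]
        if self.normalize:
            norm = math.sqrt(sum(x * x for x in pooled))
            if norm > 0.0:
                pooled = [x / norm for x in pooled]
        return pooled
```

Keeping the pooling choice in a constructor argument is what would let it be driven by a model's sentence-transformers configuration later on.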
Refactor LLMEngine
- Refactor _process_sequence_group_outputs.
- For example, we could have an abstract class called SequenceGroupProcessor, with subclasses SequenceGroupProcessorEmbedding and SequenceGroupProcessorCompletion. Each of these is responsible for implementing _process_sequence_group_outputs using PoolerOutput or SamplerOutput respectively
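The class split suggested above could be sketched as follows. Only the three class names come from the comment; the dict-based sequence group, the `process` method, and `make_processor` are simplifications invented for this example.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SequenceGroupProcessor(ABC):
    """Abstract base: each subclass knows how to apply one kind of model output."""

    @abstractmethod
    def process(self, seq_group: Dict[str, Any], output: Any) -> Dict[str, Any]:
        ...


class SequenceGroupProcessorEmbedding(SequenceGroupProcessor):
    def process(self, seq_group: Dict[str, Any], output: Any) -> Dict[str, Any]:
        # A pooled vector completes the request in a single step.
        seq_group["embedding"] = list(output)
        seq_group["finished"] = True
        return seq_group


class SequenceGroupProcessorCompletion(SequenceGroupProcessor):
    def process(self, seq_group: Dict[str, Any], output: Any) -> Dict[str, Any]:
        # A sampled token id is appended; the request ends when EOS is produced.
        seq_group.setdefault("token_ids", []).append(output)
        seq_group["finished"] = output == seq_group.get("eos_token_id")
        return seq_group


def make_processor(embedding_mode: bool) -> SequenceGroupProcessor:
    """Pick the processor once, so the engine loop itself needs no branching."""
    if embedding_mode:
        return SequenceGroupProcessorEmbedding()
    return SequenceGroupProcessorCompletion()
```

Selecting the processor at engine construction time keeps the per-step code path free of embedding-versus-completion conditionals.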
@robertgshaw2-neuralmagic thanks for the feedback. Working on resolving the comments and getting the checklist. Will update it soon.
Thanks @CatherineSue
Apologies for the delay on getting this reviewed and thank you so much for your contribution :)
@robertgshaw2-neuralmagic I resolved all the comments. Here's an overview of the tasks checked in the new commits:
- [x] Rename BlockSpaceManagerV3 to EmbeddingModelBlockSpaceManager
- [x] Use ModelRegistry to check for embedding models
- [x] Tests
- [x] Clean up llama_embedding.py
- [x] Update terminology from LlamaEmbeddingModel.embedding to LlamaEmbeddingModel.pooler
- [x] Replace SamplerXXX with PoolerXXX. I added Pooler, PoolingParams, PoolerOutput, and PoolingMetadata. Note that Pooler is not following sentence_transformer's config. I didn't have time to finish it.
- [x] LLM.encode. I have separated it from LLM.generate
- [x] Log warning from LLM.generate() that this API will change for embedding models. Since I have separated the two, I didn't add the warning.
@CatherineSue - awesome!
Are you planning to resolve the merge conflicts?
Should I review now?
@robertgshaw2-neuralmagic I can resolve them if it is easier for you to review. Might take a while, ETA tonight. Does it work for you?
that works, ill review tomorrow
ping me when ready
I can take a first pass too whenever it's ready if @robertgshaw2-neuralmagic doesn't get there before me :)
@robertgshaw2-neuralmagic @ywang96 Just finished rebase
@ywang96 thanks for the quick review. Will start addressing comments after Robert has another look.
Comments above.
Big items:
- The CI is red. Looks like you are missing a dependency. Add it to requirements-dev.txt
Small items:
- [ ] Revert LLM.generate interface change. The interface for LLM.generate changed (because we formerly had RequestOutput and now we expose CompletionRequestOutput). This is a breaking change. It's suboptimal, but I think we should have RequestOutput for LLM.generate() and EmbeddingRequestOutput for LLM.encode(). This will avoid a breaking change.
- [ ] Expand Test Cases. I like the test framework. Can you expand it to make sure there is some batching that gets triggered?
Questions:
- Have you run this through MTEB or anything that can do a holistic check of correctness?
- Have you run any benchmarking to see if these changes impact performance of the generative models at all?
In terms of follow ups:
- Are you planning to continue on the roadmap / refactor that I laid out?
Resolved most comments.
The CI is red. Looks like you are missing a dependency. Add it to requirements-dev.txt
Added the dep. Waiting for CI.
- [x] Revert LLM.generate interface change.
The interface for LLM.generate changed (because we formerly had RequestOutput and now we expose CompletionRequestOutput). This is a breaking change. It's suboptimal, but I think we should have RequestOutput for LLM.generate() and EmbeddingRequestOutput for LLM.encode(). This will avoid a breaking change.
I reverted the API. Does this mean I need to revert CompletionRequestOutput back to RequestOutput?
- [x] Expand Test Cases. I like the test framework. Can you expand it to make sure there is some batching that gets triggered?
Added test_batch_embedding in test_openai_server.py
Questions:
Have you run this through MTEB or anything that can do a holistic check of correctness?
I tried, but the LLM.encode() interface is not consistent with SentenceTransformer, specifically batch_size, and it would raise errors in MTEB.
We have performed correctness testing internally with our science team for their use case.
Have you run any benchmarking to see if these changes impact performance of the generative models at all?
No, it is a valuable benchmark to try. I didn't have time to run regression test for the generative models. Is there any regression benchmark in CI?
In terms of follow ups:
Are you planning to continue on the roadmap / refactor that I laid out?
Unfortunately, I don't have bandwidth to continue the future tasks on the roadmap.
@ywang96 @robertgshaw2-neuralmagic I resolved most of the comments. Please see the summary above.
ModelTest keeps failing due to OOM. Is the model size too big?
I tried to set dtype to half in test_llama_embedding. But it didn't seem to help. I ran it alone locally and the test passed.
@CatherineSue shoot me a note when ready!
@robertgshaw2-neuralmagic It is ready. Not sure how to fix Entrypoints Test's ray resource issue or Models Test's OOM error. Suggestions appreciated.
@CatherineSue Ill take a look at getting the CI to pass
@CatherineSue we have release on Friday, so I think we should merge this after
@robertgshaw2-neuralmagic sounds good. Do you plan to take another look at this PR next week then?
I have reviewed in detail and feel good about it, just have been waiting for a clean CI build + havent had a chance to fix it myself
@CatherineSue
https://github.com/CatherineSue/vllm/blob/1c508625a85449c83c8fc1f2f99d78e7035fcbb6/vllm/model_executor/layers/pooler.py#L13
Last token is good, but some models use mean pooling as well. "mean", "max", "cls", "weightedmean", "lasttoken" are all important.
https://github.com/xlang-ai/instructor-embedding/blob/5cca65eb0ed78ab354b086a5386fb2c528809caa/InstructorEmbedding/instructor.py#L68
Another question: this push will only solve llama's embedding problem. https://github.com/vllm-project/vllm/blob/1c508625a85449c83c8fc1f2f99d78e7035fcbb6/vllm/model_executor/models/llama_embedding.py#L29
What about support for other models?
@WuNein
"mean", "max", "cls", "weightedmean", "lasttoken" are all important. Another question , this push will only solve llama's embedding problem.
These are not within the scope of this PR. I believe there will be following PRs to cover these topics.
@CatherineSue For me, returning the last hidden state is everything. Pooling can be implemented separately. The solution you offer is more difficult to use with other models.
We can expand upon this in a follow up PR