
[RFC]: Extending vLLM towards native support of non-text-generating models

Status: Open. christian-pinto opened this issue 9 months ago • 4 comments

Motivation.

This RFC proposes a set of changes to better support non-text-generating models, with the ultimate goal of making vLLM the engine of choice for multimodal input/output models. The target is the V1 engine. This is a follow-up to a previous RFC (#11065), which led to the merging of the PrithviGeospatialMAE model: by piggybacking on the embedding model interface, it became the first model in vLLM that generates (raw) images instead of text.

In a nutshell, I want to start a discussion about proper support for models that generate output in various modalities.

Specifically, the proposed changes would target:

  • vLLM interface and serving API
  • Generation of output data that is not text
  • Processing of models' input

Proposed Change.

vLLM interface and main loop

The current vLLM entrypoint/interface (LLM/LLMEngine/AsyncLLMEngine) clearly targets language models: (1) a prompt (text or tokens) is always expected, while multi-modal data is optional; (2) the main engine assumes auto-regressive text generation.

Supporting non-text-generating models means enabling users to pass whatever input type their model requires (including text, if needed) and to receive output in the model's output modality. Also, some models run inference in a single pass (e.g., the PrithviGeospatialMAE model I previously merged), while others follow an iterative, yet not auto-regressive, process (e.g., diffusion models). The current integration of the only non-text-generating model in vLLM (PrithviGeospatialMAE) relies on the embedding/pooling abstraction. This has worked so far because we only extract the raw output of the model, while the output image is generated in post-processing outside vLLM. This new set of changes would focus on better integration, enabling users to feed images/sound/video to vLLM as input and to receive images/sound/video as output.
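To make the difference between the three execution patterns concrete, here is a toy sketch. The `step` callable stands in for one model forward pass, and the function names `run_one_pass`, `run_fixed_iterations`, and `run_autoregressive` are hypothetical illustrations, not vLLM APIs.

```python
from typing import Callable, List

# Hypothetical stand-in for one model forward pass.
Step = Callable[[int], int]


def run_one_pass(step: Step, x: int) -> int:
    """One-pass model (e.g. PrithviGeospatialMAE): a single forward pass."""
    return step(x)


def run_fixed_iterations(step: Step, x: int, num_steps: int) -> int:
    """Iterative, non-auto-regressive model (e.g. diffusion): the number of
    steps is fixed up front and each output becomes the next input."""
    for _ in range(num_steps):
        x = step(x)
    return x


def run_autoregressive(step: Step, prompt: List[int], stop: int,
                       max_len: int) -> List[int]:
    """Auto-regressive LLM: append one token per step until the stop token
    is produced (it is kept in the output) or max_len is reached."""
    tokens = list(prompt)
    while len(tokens) < max_len:
        nxt = step(tokens[-1])
        tokens.append(nxt)
        if nxt == stop:
            break
    return tokens
```

Only the auto-regressive loop has a data-dependent termination condition; the other two know their number of forward passes before starting.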

In light of the above, I believe a new interface is required, distinct from the classic LLM/LLMEngine/AsyncLLMEngine, named along the lines of ImageEngine/AsyncImageEngine. There are two main reasons: (1) the primary role of an LLM is to generate text; (2) the main loop executed might be different from that of an LLM.

The new interface should have the following characteristics:

  • Allow for multimodal data as input with no mandatory text/token IDs prompt
  • Return the appropriate output data format
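A minimal sketch of a request accepted by such an interface, assuming a hypothetical `MultiModalRequest` dataclass (not an existing vLLM type): the text/token prompt is optional, and validation only requires that some input be present.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class MultiModalRequest:
    """Hypothetical request type where the text/token prompt is optional."""
    prompt: Optional[Union[str, List[int]]] = None
    # Modality name -> raw data, e.g. {"image": "/path/to/scene.tif"}.
    multi_modal_data: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Reject only requests that carry no input at all.
        if self.prompt is None and not self.multi_modal_data:
            raise ValueError("request needs a prompt or multi-modal data")
```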

Open questions:

  • Should we have a different interface for each supported output modality?
  • Could we reuse the main loop of the current vLLM (V1) engine, piggy-backing on the existing support for auto-regressive models for iterative generation and on the existing support for pooling models for one-pass generation? If so, only the entrypoint changes and the same engine is instantiated; otherwise, in addition to the new interface/entrypoint, we would also define an entirely new engine class.

Using this new interface would be equivalent to using the existing one. As a first step, we could imagine a single entrypoint function (generate()) that triggers output generation according to the model's output modality.
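One way a single `generate()` entrypoint could dispatch on the model's output modality is sketched below. `ModalityEngine`, its `handlers` mapping, and the assumption that the output modality is fixed when the model is loaded are illustrative choices, not existing vLLM APIs.

```python
from typing import Any, Callable, Dict


class ModalityEngine:
    """Hypothetical engine front-end exposing a single generate() entrypoint."""

    def __init__(self, output_modality: str,
                 handlers: Dict[str, Callable[[Any], Any]]) -> None:
        if output_modality not in handlers:
            raise ValueError(f"unsupported output modality: {output_modality}")
        # Assumption: the modality is a property of the loaded model.
        self.output_modality = output_modality
        self._handler = handlers[output_modality]

    def generate(self, inputs: Any) -> Any:
        # Same call for every model; the return type depends on the modality.
        return self._handler(inputs)
```

Keeping one entrypoint means callers do not need to know which engine variant backs the model; the per-modality logic lives in the handlers.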

Regarding the serving API, it could be extended to support image/video/audio generation, as these are already available in the OpenAI API.
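For reference, the OpenAI image-generation API returns a JSON object with a `created` timestamp and a `data` list whose entries carry either a `url` or a `b64_json` field. A minimal sketch of packing already-encoded image bytes into that shape; the helper name `to_openai_image_response` is hypothetical.

```python
import base64
import json
import time
from typing import List, Optional


def to_openai_image_response(images: List[bytes],
                             created: Optional[int] = None) -> str:
    """Pack encoded images (e.g. PNG bytes) into an OpenAI-style
    image-generation response using the `b64_json` field."""
    payload = {
        "created": int(time.time()) if created is None else created,
        "data": [{"b64_json": base64.b64encode(img).decode("ascii")}
                 for img in images],
    }
    return json.dumps(payload)
```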

Does this sound reasonable to people?

Generation of data that is not text

In the current implementation, text is the only output vLLM can generate. This is handled by the OutputProcessor class, which is invoked by the output_handler task in the main V1 engine loop.

I propose extending the OutputProcessor capabilities along the same lines as multi-modal input processing: a new MULTIMODAL_OUTPUT_REGISTRY would be defined, in which users register their output processors.

This might also require renaming the current MULTIMODAL_REGISTRY to MULTIMODAL_INPUT_REGISTRY.

Since this feature is used only by non-text-generating models, existing LLMs are unaffected by the additional output processor. Text-generating models default to the existing OutputProcessor class, which handles de-tokenization of generated tokens.

A (very) abstract example of the new registry and its integration into the existing OutputProcessor loop could look as follows.

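class MultimodalOutputRegistry:
    def __init__(self, default_processor_cls):
        self._default_cls = default_processor_cls  # e.g. the de-tokenizing text processor
        self._processors = {}  # model class -> registered output processor class

    def register_processor(self, processor_cls):
        def decorator(model_cls):  # class decorator applied to the model class
            self._processors[model_cls] = processor_cls
            return model_cls
        return decorator

    def create_output_processor(self, model_cls=None):
        # Fall back to the default processor if the model registered nothing.
        return self._processors.get(model_cls, self._default_cls)()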

We then decorate the model's main class with:

@MULTIMODAL_OUTPUT_REGISTRY.register_processor(MyOutputProcessor)
class MyModel:
    ...

At runtime, the AsyncLLM class instantiates an OutputProcessor, which the output handler task then uses to generate text. The OutputProcessor class could be extended as in the example below, where the user-provided output processor is instantiated at engine initialization time. The default text processor is used unless the model registers its own.

from __future__ import annotations  # annotations below name vLLM V1 types

from typing import Optional


class OutputProcessor:
    def __init__(self,
                 model_cls=None,
                 output_registry=MULTIMODAL_OUTPUT_REGISTRY):
        # Per-request state keyed by request ID (as in the existing class).
        self.request_states = {}
        # Returns the output processor registered by the model, or the
        # "default" TextOutputProcessor for text-generating models.
        self.output_processor = output_registry.create_output_processor(model_cls)

    def process_outputs(
        self,
        engine_core_outputs: list[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        request_outputs: list[RequestOutput] = []
        reqs_to_abort: list[str] = []
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
            if req_state is None:
                # Ignore output for already-aborted request.
                continue

            # 1) Compute stats for this iteration.
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            # 2) Delegate generation of the output data (text, image, ...)
            #    to the model-specific output processor.
            self.output_processor(req_id, engine_core_output)

            ...

With this design, the output data can be of any type. For example, an output processor that generates images might return a file path for offline inference. For online serving, if we follow the OpenAI image creation API, the output is translated into an image object (a URL or base64-encoded JSON). This is in the same spirit as pooling models, which are allowed to return anything in PoolingSequenceGroupOutput.
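A sketch of an image output processor following the behavior described above: it returns a file path for offline inference and an OpenAI-style `b64_json` object for online serving. `ImageOutputProcessor`, its `online` flag, and the assumption that the model's raw output has already been encoded into PNG bytes are all hypothetical.

```python
import base64
import os
import tempfile
from typing import Dict, Union


class ImageOutputProcessor:
    """Hypothetical per-model output processor for image-generating models."""

    def __init__(self, online: bool, output_dir: str = "") -> None:
        self.online = online
        self.output_dir = output_dir or tempfile.gettempdir()

    def __call__(self, request_id: str,
                 image_bytes: bytes) -> Union[str, Dict[str, str]]:
        # Assumption: the raw model output has already been encoded (e.g. PNG).
        if self.online:
            # Online serving: return an image object with base64-encoded data.
            return {"b64_json": base64.b64encode(image_bytes).decode("ascii")}
        # Offline inference: write the image to disk and return its path.
        path = os.path.join(self.output_dir, f"{request_id}.png")
        with open(path, "wb") as f:
            f.write(image_bytes)
        return path
```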

Even though this modification sits deep in the engine's main loop, it will not negatively affect the inference performance of existing models, since the new path is not used by classic LLMs.

Processing of models' input

Right now, any model can register a processor that parses the multi-modal input data, transforms it into raw data (e.g., pixel values), and feeds it to the model together with the input prompt. The usual auto-regressive process then kicks in and runs until either the maximum context length is reached or a stop token is emitted.

For some models, such as PrithviGeospatialMAE, I would like to define a multimodal input processor that takes the path to a GeoTIFF file as input and parses it by splitting it into patches. The patches are then fed to the model one by one, or in batches, and the process finishes when we run out of patches. This is an iterative process like auto-regression, but we know beforehand how many times to run inference on the model.
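A minimal sketch of the patch-splitting step, assuming the GeoTIFF has already been loaded into a 2-D grid (a list of rows); reading the file itself is out of scope here. The function name and the choice to drop incomplete edge patches are assumptions; a real processor might pad instead. The point is that the number of patches, and hence of inference steps, is known before the first forward pass.

```python
from typing import List

Grid = List[List[float]]


def split_into_patches(grid: Grid, patch: int) -> List[Grid]:
    """Split a 2-D grid into non-overlapping patch x patch tiles in
    row-major order. Edge regions smaller than a full patch are dropped."""
    rows = len(grid)
    cols = len(grid[0]) if grid else 0
    patches = []
    for r in range(0, rows - patch + 1, patch):
        for c in range(0, cols - patch + 1, patch):
            patches.append([row[c:c + patch] for row in grid[r:r + patch]])
    return patches
```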

I would like this behavior to be triggered by the input processor, which determines how much data to run inference on; inference is then executed multiple times, over all the data, under the same user request. The outputs are kept and post-processed by the output processor described above.
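The control flow described above, where the input processor fixes the number of forward passes and all results stay attached to one request until post-processing, could look roughly like this. `run_request`, `forward`, and `post_process` are hypothetical placeholders, and batching is shown with a simple fixed batch size.

```python
from typing import Callable, List, Sequence, TypeVar

P = TypeVar("P")  # one input patch
O = TypeVar("O")  # raw model output for one patch
R = TypeVar("R")  # final post-processed result


def run_request(patches: Sequence[P],
                forward: Callable[[List[P]], List[O]],
                post_process: Callable[[List[O]], R],
                batch_size: int = 4) -> R:
    """Run inference over all patches of one request, batch by batch,
    then hand the collected raw outputs to the output processor."""
    # Known up front: ceil(len(patches) / batch_size) forward passes.
    num_steps = (len(patches) + batch_size - 1) // batch_size
    outputs: List[O] = []
    for step in range(num_steps):
        batch = list(patches[step * batch_size:(step + 1) * batch_size])
        outputs.extend(forward(batch))
    return post_process(outputs)
```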

Process tracking

I would split this in multiple tasks:

  1. Creation of an additional entrypoint to support non-text-generating models

  2. Adding output processors for non text-generating models

  3. Creation of input and output processors for PrithviGeospatialMAE. The output would still be the raw model output for all the patches we run inference on (i.e., a list of tensors)

  4. Adding support for online serving of multimodal output models

Feedback Period.

2 weeks

CC List.

@DarkLight1337 @ywang96 @njhill @alex-jw-brooks

Any Other Things.

No response


christian-pinto, Apr 04 '25

Some updates/thoughts:

  • The hidden states processor RFC (https://github.com/vllm-project/vllm/issues/12249) is quite similar to your idea of output processor, and is planned to be discussed in more detail by the vLLM team since it introduces substantial changes to the core architecture.
    • We already have a PR in progress to add pooling models to V1. It is designed with hidden states processor in mind and thus also warrants further discussion.
  • @ywang96 originally wanted to support multi-modal output via a separate package under the vllm-project organization. This fits well with your idea of MULTIMODAL_OUTPUT_REGISTRY, which supports OOT (out-of-tree) registration.
    • We could potentially move the contents of omni_llm_engine.py in https://github.com/vllm-project/vllm/pull/16347 into a separate package if we want to take that direction.

DarkLight1337, Apr 09 '25

> Some updates/thoughts:
>
>   • The hidden states processor RFC ([RFC]: Hidden states processor #12249) is quite similar to your idea of output processor, and is planned to be discussed in more detail by the vLLM team since it introduces substantial changes to the core architecture.
>
>     • We already have a PR in progress to add pooling models to V1. It is designed with hidden states processor in mind and thus also warrants further discussion.

Exactly, this would be helpful down the road for supporting multimodal output. In my proposal, the hidden states would then be processed by an output processor. vLLM already has an output processor, though it only performs de-tokenization.

Right, ideally output processors would be registered per model, and I see this working well with the RFC you mentioned. Users would register their processor either per model or even at runtime as a plugin, as @ywang96 suggests, and it would process the hidden states.

What about the fact that some models might not be auto-regressive, but instead iterative or "one pass" models? Would you envision them still going through the embeddings path? Or would you be open to having a dedicated engine for handling them?

christian-pinto, Apr 09 '25

> What about the fact that some models might not be auto-regressive, but instead iterative or "one pass" models?

We can treat "one pass" models in the same way as pooling models.

As for iterative models, we can refer to #16347 which reuses prompt_embeds in subsequent iterations.

DarkLight1337, Apr 09 '25

IMO we can focus on the following use cases first:

  • TTS (with streaming)
  • Image segmentation
  • Video segmentation/prediction (with streaming)

DarkLight1337, Apr 09 '25

> > What about the fact that some models might not be auto-regressive, but instead iterative or "one pass" models?
>
> We can treat "one pass" models in the same way as pooling models.
>
> As for iterative models, we can refer to #16347 which reuses prompt_embeds in subsequent iterations.

Thanks, let me have a look at this.

christian-pinto, Apr 14 '25

> IMO we can focus on the following use cases first:
>
>   • TTS (with streaming)
>   • Image segmentation
>   • Video segmentation/prediction (with streaming)

Sounds good to me. My immediate use-case will be image segmentation.

christian-pinto, Apr 14 '25

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions[bot], Jul 14 '25

As a first step in this process, I have re-enabled the Prithvi Geospatial models on V1 (#20072, #20577), which also required enabling support for attention-free models (#20811).

christian-pinto, Jul 14 '25

I'm considering this completed. The hidden states processor has already been integrated into vLLM. After some discussion within the vLLM core group, we have decided that this is where the scope of multimodal output generation stops for vllm-project/vllm.

Happy to discuss offline if you have any questions about it!

ywang96, Oct 07 '25