
[Usage]: How to bypass multimodal processor logic when inputs are already processed

Open yanyongyu opened this issue 9 months ago • 14 comments

Your current environment

vllm 0.7.3
transformers 4.49.0

How would you like to use vllm

I'm writing a custom multimodal model to generate content from text/image/audio inputs. However, in most cases the generation request already contains processed content, like this:

TokensPrompt(
    prompt_token_ids=[xxx],  # list of token ids, already expanded
    multi_modal_data={
        "image": {
            "pixel_values": xxx,
            "image_grid_thw": xxx,
        },
        "audio": {
            "audio_values": xxx,
            "audio_attention_masks": xxx,
        },
    },
)

The prompt token ids are already expanded to the feature size of each multimodal item. I'm writing a processor adapted from qwen2_5_vl and don't know how to bypass the processor (including the prompt replacements).

I have tested the code with batch size 1, but it fails with larger batch sizes (with enable_chunked_prefill=True and enforce_eager=True). The error occurs when decoding the second token and indicates that input_ids contains the feature length of one image, but mm_inputs contains pixel_values for two images.

I'm trying to find a way to completely bypass the processor when inputs are already processed.

Some code reference:

def _field_config(hf_inputs: Mapping[str, torch.Tensor]):
    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
    image_grid_sizes = image_grid_thw.prod(-1)

    return {
        "pixel_values": MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes
        ),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
        "image_embeds": MultiModalFieldConfig.flat_from_sizes(
            "image", image_grid_sizes
        ),
        "audio_values": MultiModalFieldConfig.batched("audio"),
        "audio_attention_masks": MultiModalFieldConfig.batched("audio"),
    }

class CustomMultiModalDataParser(MultiModalDataParser):
    def _parse_image_data(
        self, data: ModalityData[ImageItem]
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields={"pixel_values", "image_grid_thw"},
                fields_factory=_field_config,
            )
        return super()._parse_image_data(data)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_values", "audio_attention_masks"},
                fields_factory=_field_config,
            )
        return super()._parse_audio_data(data)

class CustomMultiModalProcessor(BaseMultiModalProcessor[CustomProcessingInfo]):
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _field_config(hf_inputs)

    def _get_data_parser(self) -> CustomMultiModalDataParser:
        return CustomMultiModalDataParser()

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index
        audio_token_id = hf_config.audio_token_index
        merge_size = hf_config.vision_config.spatial_merge_size

        # trick here

        def get_image_replacement(item_idx: int):
            # only do replace when inputs are not embeddings
            if (
                item := mm_items.get_items(
                    "image", (DictEmbeddingItems, ImageProcessorItems)
                )
            ) and isinstance(item, DictEmbeddingItems):
                return [image_token_id]

            grid_thw = out_mm_kwargs["image_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)

            num_image_tokens = int(grid_thw.prod()) // (merge_size**2)
            return [image_token_id] * num_image_tokens

        def get_audio_replacement(item_idx: int):
            # only do replace when inputs are not embeddings
            if (
                item := mm_items.get_items(
                    "audio", (DictEmbeddingItems, AudioProcessorItems)
                )
            ) and isinstance(item, DictEmbeddingItems):
                return [audio_token_id]

            mask = out_mm_kwargs["audio_attention_masks"][item_idx]
            assert isinstance(mask, torch.Tensor)
            num_audio_tokens = int(mask.sum().item())
            return [audio_token_id] * num_audio_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_image_replacement,
            ),
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_audio_replacement,
            ),
        ]

Before submitting a new issue...

  • [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

yanyongyu • Mar 05 '25 11:03
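The size bookkeeping that `_field_config` and `get_image_replacement` rely on can be checked without vLLM. Below is a pure-Python sketch, assuming Qwen2.5-VL-style inputs where each image contributes t * h * w rows to the flat `pixel_values` tensor and a `spatial_merge_size` of 2 (an assumption; read it from your config). All helper names are hypothetical, and `split_flat` only mimics the idea behind `MultiModalFieldConfig.flat_from_sizes`, not its implementation.

```python
from math import prod


def patches_per_image(grid_thw: list[list[int]]) -> list[int]:
    """Rows each image occupies in the flat pixel_values tensor (t * h * w)."""
    return [prod(thw) for thw in grid_thw]


def tokens_per_image(grid_thw: list[list[int]], merge_size: int = 2) -> list[int]:
    """Placeholder tokens per image: merge_size**2 patches collapse into one token."""
    return [n // (merge_size ** 2) for n in patches_per_image(grid_thw)]


def split_flat(rows: list, sizes: list[int]) -> list[list]:
    """Cut a flat sequence into consecutive per-item slices (flat_from_sizes-style)."""
    if sum(sizes) != len(rows):
        raise ValueError(f"sizes sum to {sum(sizes)}, but there are {len(rows)} rows")
    chunks, start = [], 0
    for size in sizes:
        chunks.append(rows[start:start + size])
        start += size
    return chunks


def audio_tokens(attention_mask: list[int]) -> int:
    """Audio placeholder count, as in get_audio_replacement: the mask's sum."""
    return sum(attention_mask)


if __name__ == "__main__":
    grid = [[4, 22, 30]]  # image_grid_thw reported in the logs of this thread
    print(patches_per_image(grid))  # [2640], matching pixel_values.shape[0]
    print(tokens_per_image(grid))   # [660], matching the placeholder length
```

Plugging in the grid_thw from the logs (4 * 22 * 30) gives 2640 patch rows and 660 tokens, consistent with the shapes reported there.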

If you want to completely bypass the multimodal processor, the easiest way would be to just override apply to return the input kwargs directly. (You probably still have to search for the feature tokens in the prompt IDs.)

DarkLight1337 • Mar 05 '25 14:03
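Searching for the feature tokens in the prompt IDs could look like the pure-Python sketch below, which returns ranges in the same `{'offset', 'length'}` shape as the placeholder logs later in the thread. It assumes each item's token count is known in advance (e.g. from grid_thw or the audio mask); the function name is hypothetical and this is not vLLM's own placeholder search.

```python
def find_placeholder_ranges(
    prompt_ids: list[int], token_id: int, lengths: list[int]
) -> list[dict[str, int]]:
    """Locate one run of `token_id` per item, in order.

    `lengths[i]` is the expected number of placeholder tokens for item i.
    Returns ranges shaped like the {'offset', 'length'} placeholders in the logs.
    """
    ranges: list[dict[str, int]] = []
    start = 0
    for i, length in enumerate(lengths):
        try:
            offset = prompt_ids.index(token_id, start)
        except ValueError:
            raise ValueError(f"item {i}: no placeholder token after position {start}") from None
        run = prompt_ids[offset:offset + length]
        # The whole run must consist of placeholder tokens.
        if len(run) != length or any(tok != token_id for tok in run):
            raise ValueError(f"item {i}: expected {length} placeholder tokens at {offset}")
        ranges.append({"offset": offset, "length": length})
        start = offset + length
    # Leftover placeholders mean the prompt and the data disagree on item count.
    if token_id in prompt_ids[start:]:
        raise ValueError("more placeholder tokens in the prompt than items in the data")
    return ranges
```

Validating the run length and rejecting leftover placeholders surfaces a prompt/data mismatch at processing time instead of inside the model's embedding merge.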

If I rewrite the apply method of the processor, how should I handle the dummy input data? Should I change the dummy inputs builder at the same time?

yanyongyu • Mar 05 '25 14:03

To keep things simple, I would add a flag like "is_processed" to your inputs and only use your own logic if you detect that flag in the inputs when calling apply. This lets you keep the existing code intact.

DarkLight1337 • Mar 05 '25 14:03
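A sketch of the flag-based dispatch suggested above, in plain Python so it runs without vLLM: an overridden `apply` would call something like `dispatch_apply`, passing the custom logic as `preprocessed_path` and `super().apply` as `default_path`. The `is_processed` key and all names here are hypothetical. Because dummy inputs carry no flag, they keep flowing through the existing processor, so the dummy inputs builder can likely stay unchanged.

```python
from typing import Any, Callable, Mapping, Union

PROCESSED_FLAG = "is_processed"  # hypothetical key chosen by the caller


def dispatch_apply(
    prompt: Union[str, list[int]],
    mm_data: Mapping[str, Any],
    preprocessed_path: Callable[[list[int], dict[str, dict[str, Any]]], Any],
    default_path: Callable[[Any, Mapping[str, Any]], Any],
) -> Any:
    """Route flagged (already processed) requests to custom logic, others to the default."""
    flagged = {
        modality: data
        for modality, data in mm_data.items()
        if isinstance(data, Mapping) and data.get(PROCESSED_FLAG, False)
    }
    if not flagged:
        # Raw inputs (including the profiler's dummy inputs) keep the normal path.
        return default_path(prompt, mm_data)
    if len(flagged) != len(mm_data) or not isinstance(prompt, list):
        raise ValueError("preprocessed inputs need token ids and every modality flagged")
    # Strip the flag so only real tensors reach MultiModalKwargs.
    cleaned = {
        modality: {k: v for k, v in data.items() if k != PROCESSED_FLAG}
        for modality, data in flagged.items()
    }
    return preprocessed_path(prompt, cleaned)
```

Unlike a check that requires both an image dict and an audio dict, this accepts any subset of modalities, so image-only or audio-only preprocessed requests also take the fast path.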

After rewriting the apply method, the processor seems to work. But I encountered a concurrency problem.

When I enqueued two generation requests (with the same input), the model received the first request's prefill and then performed decoding in the second forward call. The first forward call is fine: input ids of shape [seq_len], one image's pixel values, and one audio's values. In the second forward call, however, the input token ids have length [seq_len + 1] and the mm kwargs contain the images and audio of both requests. This causes a mismatch in the feature embedding token count. I'm not sure why this happens. The engine is V0 with vLLM 0.7.3. Thank you for helping 🙏.

yanyongyu • Mar 06 '25 06:03

What does your code look like now?

DarkLight1337 • Mar 06 '25 07:03

Here is the processor main logic and the model:

class CustomMultiModalProcessor(BaseMultiModalProcessor[CustomProcessingInfo]):
    def _extract_image_placeholders(
        self, prompt_ids: list[int], image_data: dict[str, Any]
    ) -> list[PlaceholderFeaturesInfo]:
        # get feature token length from input

        start_idx = 0
        placeholders: list[PlaceholderFeaturesInfo] = []
        for i, feature_token_length in enumerate(feature_token_lengths):
            feature_start = prompt_ids.index(image_token_id, start_idx)
            feature_end = feature_start + feature_token_length
            feature_tokens = prompt_ids[feature_start:feature_end]
            assert set(feature_tokens) == {image_token_id}
            placeholders.append(
                PlaceholderFeaturesInfo(
                    modality="image",
                    item_idx=i,
                    start_idx=feature_start,
                    tokens=feature_tokens,
                )
            )
            start_idx = feature_end
        return placeholders

    def _extract_audio_placeholders(
        self, prompt_ids: list[int], audio_data: dict[str, Any]
    ) -> list[PlaceholderFeaturesInfo]:
        # get feature token length from input

        start_idx = 0
        placeholders: list[PlaceholderFeaturesInfo] = []
        for i, feature_token_length in enumerate(feature_token_lengths):
            feature_start = prompt_ids.index(audio_token_id, start_idx)
            feature_end = feature_start + feature_token_length
            feature_tokens = prompt_ids[feature_start:feature_end]
            assert set(feature_tokens) == {audio_token_id}
            placeholders.append(
                PlaceholderFeaturesInfo(
                    modality="audio",
                    item_idx=i,
                    start_idx=feature_start,
                    tokens=feature_tokens,
                )
            )
            start_idx = feature_end
        return placeholders

    def _extract_placeholders(
        self,
        prompt_ids: list[int],
        image_data: dict[str, Any],
        audio_data: dict[str, Any],
    ) -> dict[str, list[PlaceholderFeaturesInfo]]:
        return {
            "image": self._extract_image_placeholders(prompt_ids, image_data),
            "audio": self._extract_audio_placeholders(prompt_ids, audio_data),
        }

    def apply(
        self,
        prompt: str | list[int],
        mm_data: Mapping[str, Any | list[Any]],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        if (
            isinstance(prompt, list)
            and isinstance(image_data := mm_data.get("image", None), dict)
            and isinstance(audio_data := mm_data.get("audio", None), dict)
        ):
            prompt_ids = prompt
            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)

            hf_inputs = BatchFeature({**image_data, **audio_data})
            mm_kwargs = MultiModalKwargs.from_hf_inputs(
                hf_inputs, self._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
            )

            mm_placeholders = self._extract_placeholders(
                prompt_ids, image_data, audio_data
            )

            mm_placeholder_ranges = {
                modality: [item.to_range() for item in placeholders]
                for modality, placeholders in mm_placeholders.items()
            }
            return MultiModalInputs(
                type="multimodal",
                prompt=prompt,
                prompt_token_ids=prompt_ids,
                mm_kwargs=mm_kwargs,
                mm_hashes=None,
                mm_placeholders=mm_placeholder_ranges,
            )
        return super().apply(prompt, mm_data, hf_processor_mm_kwargs)

@MULTIMODAL_REGISTRY.register_processor(
    CustomMultiModalProcessor,
    info=CustomProcessingInfo,
    dummy_inputs=CustomDummyInputsBuilder,
)
class CustomModelForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

    def _parse_and_validate_image_input(self, **kwargs: object) -> ImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_image_tensor(
                pixel_values, "image pixel values"
            )
            image_grid_thw = self._validate_and_reshape_image_tensor(
                image_grid_thw, "image grid_thw"
            )

            if not isinstance(pixel_values, torch.Tensor | list):
                raise ValueError(
                    "Incorrect type of image pixel values. "
                    f"Got type: {type(pixel_values)}"
                )

            return ImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_image_tensor(
                image_embeds, "image embeds"
            )
            image_grid_thw = self._validate_and_reshape_image_tensor(
                image_grid_thw, "image grid_thw"
            )

            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError(
                    "Incorrect type of image embeddings. "
                    f"Got type: {type(image_embeds)}"
                )
            return ImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_audio_input(self, **kwargs: object) -> AudioInputs | None:
        audio_values = kwargs.pop("audio_values", None)
        audio_attention_masks = kwargs.pop("audio_attention_masks", None)

        if audio_values is None or audio_attention_masks is None:
            return

        audio_values = self._validate_and_reshape_audio_tensor(audio_values)
        audio_attention_masks = self._validate_and_reshape_audio_mask_tensor(
            audio_attention_masks
        )
        if not isinstance(audio_attention_masks, torch.Tensor):
            raise ValueError(
                "Incorrect type of audio attention masks. "
                f"Got type: {type(audio_attention_masks)}"
            )
        audio_values = audio_values.to(self.dtype)

        return AudioInputs(
            audio_values=audio_values, audio_attention_masks=audio_attention_masks
        )

    def _process_image_input(
        self, image_input: ImageInputs
    ) -> tuple[torch.Tensor, ...]:
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype)
            # compute image feature
            image_embeds = xxx

        # Split concatenated embeddings for each image item.
        merge_size = self.vision_tower.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size

        return image_embeds.split(sizes.tolist())

    def _process_audio_input(self, audio_input: AudioInputs):
        audio_values = audio_input["audio_values"]
        audio_attention_masks = audio_input["audio_attention_masks"]

        # compute audio feature
        audio_features: torch.Tensor = xxx
        return audio_features

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> tuple[Iterable[torch.Tensor] | None, Iterable[torch.Tensor] | None] | None:
        vision_embeddings = None
        if image_input := self._parse_and_validate_image_input(**kwargs):
            vision_embeddings = self._process_image_input(image_input)

        audio_embeddings = None
        if audio_input := self._parse_and_validate_audio_input(**kwargs):
            audio_embeddings = self._process_audio_input(audio_input)

        return vision_embeddings, audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: tuple[
            Iterable[torch.Tensor] | None, Iterable[torch.Tensor] | None
        ]
        | None = None,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        inputs_embeds: torch.Tensor = self.language_model.get_input_embeddings(
            input_ids
        )
        if multimodal_embeddings is not None:
            vision_embeddings, audio_embeddings = multimodal_embeddings
            if vision_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    list(vision_embeddings),
                    self.config.image_token_index,
                )
            if audio_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    list(audio_embeddings),
                    self.config.audio_token_index,
                )
        return inputs_embeds

    def forward(  # type: ignore
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, multimodal_embeddings)

            input_ids = None  # type: ignore

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

yanyongyu • Mar 06 '25 08:03

Did you enable chunked prefill? This is only supported in V1, even when a V1-compatible processor is used.

DarkLight1337 • Mar 06 '25 09:03

Did you enable chunked prefill?

Yes, I enabled it.

yanyongyu • Mar 06 '25 09:03

Disabling chunked prefill doesn't seem to solve the issue. The model forward still receives too many multimodal inputs when decoding:

ValueError: Attempted to assign 660 + 660 = 1320 multimodal tokens to 660 placeholders

yanyongyu • Mar 06 '25 09:03

Can you show the shape of the tensors in the processor outputs? Are the placeholder ranges being set correctly?

I think it's difficult to debug this in detail unless you share your fork.

DarkLight1337 • Mar 06 '25 10:03

The issue only occurs when one request is in decoding stage and the other is in prefill stage.

Here is some detailed info about the tensor shapes:

# The first request starts here
processor input_ids:  953
processor pixel_values:  torch.Size([2640, 1176])
processor grid_thw:  tensor([[ 4, 22, 30]])
processor placeholders:  [{'offset': 162, 'length': 660}]
processor audio_values:  torch.Size([1, 128, 3000])
processor audio_attention_masks:  torch.Size([1, 1500])
processor placeholders:  [{'offset': 824, 'length': 123}]
model input_ids:  torch.Size([953])
attn_metadata: attn_metadata.num_prefills=1, attn_metadata.num_prefill_tokens=953, attn_metadata.num_decode_tokens=0
model input_embeds None
model pixel_values:  torch.Size([2640, 1176])
model audio_values:  torch.Size([1, 128, 3000])
model audio_attention_masks:  torch.Size([1, 1500])
model vision_embeddings:  [torch.Size([660, 3584])]
model audio_embeddings:  [torch.Size([123, 3584])]

# the first request starts to decode here
model input_ids:  torch.Size([1])
attn_metadata: attn_metadata.num_prefills=0, attn_metadata.num_prefill_tokens=0, attn_metadata.num_decode_tokens=1
model input_embeds None

# another request receives now
processor input_ids:  953
processor pixel_values:  torch.Size([2640, 1176])
processor grid_thw:  tensor([[ 4, 22, 30]])
processor placeholders:  [{'offset': 162, 'length': 660}]
processor audio_values:  torch.Size([1, 128, 3000])
processor audio_attention_masks:  torch.Size([1, 1500])
processor placeholders:  [{'offset': 824, 'length': 123}]
model input_ids:  torch.Size([954])
# 1 decode token and 953 prefill tokens
attn_metadata: attn_metadata.num_prefills=1, attn_metadata.num_prefill_tokens=953, attn_metadata.num_decode_tokens=1
model input_embeds None
# the mm input batches are doubled and mismatch the processor output
model pixel_values:  torch.Size([5280, 1176])
model audio_values:  torch.Size([2, 128, 3000])
model audio_attention_masks:  torch.Size([2, 1500])
model vision_embeddings:  [torch.Size([660, 3584]), torch.Size([660, 3584])]
...
ValueError: Attempted to assign 660 + 660 = 1320 multimodal tokens to 660 placeholders

yanyongyu • Mar 06 '25 11:03
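The per-request numbers in this log are self-consistent: 4 * 22 * 30 = 2640 patch rows, and 2640 / 2² = 660 tokens (assuming spatial_merge_size = 2). The failure is a batch-level count mismatch: in the mixed step, only the new prefill request contributes image placeholders to the 954 input ids, yet two images' worth of embeddings arrive. A small pure-Python check that mirrors the error message (not vLLM's implementation of it; the function name is hypothetical) makes this explicit:

```python
def check_assignment(placeholder_lengths: list[int], embedding_lengths: list[int]) -> None:
    """Raise if embedding rows cannot fill the placeholder positions one-to-one."""
    expected = sum(placeholder_lengths)
    got = sum(embedding_lengths)
    if got != expected:
        terms = " + ".join(str(n) for n in embedding_lengths)
        raise ValueError(
            f"Attempted to assign {terms} = {got} multimodal tokens "
            f"to {expected} placeholders"
        )


if __name__ == "__main__":
    check_assignment([660], [660])  # single prefill request: counts match
    try:
        # Mixed decode + prefill step: one set of placeholders in input_ids,
        # but image items from both requests were encoded.
        check_assignment([660], [660, 660])
    except ValueError as exc:
        print(exc)  # Attempted to assign 660 + 660 = 1320 multimodal tokens to 660 placeholders
```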

In get_multimodal_embeddings, you should make sure None values aren't in the returned tuple.

DarkLight1337 • Mar 06 '25 13:03
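One way to follow this advice, loosely modeled on the qwen2_5_vl.py code referenced in this thread (which appears to return a single flat tuple with one tensor per item, grouped by modality): instead of a `(vision, audio)` pair that may contain None, return one flat tuple. The sketch is plain Python; strings stand in for embedding tensors, and `KEY_TO_MODALITY` and `encoders` are hypothetical.

```python
from typing import Any, Callable, Optional, Sequence

# Hypothetical mapping from forward() kwarg names to modalities for this model.
KEY_TO_MODALITY = {
    "pixel_values": "image",
    "image_embeds": "image",
    "audio_values": "audio",
}


def present_modalities(kwargs: dict[str, Any]) -> list[str]:
    """Modalities with non-None inputs, in the order their keys appear in kwargs."""
    order: list[str] = []
    for key, value in kwargs.items():
        modality = KEY_TO_MODALITY.get(key)
        if modality is not None and value is not None and modality not in order:
            order.append(modality)
    return order


def get_multimodal_embeddings(
    kwargs: dict[str, Any],
    encoders: dict[str, Callable[[dict[str, Any]], Sequence[Any]]],
) -> Optional[tuple[Any, ...]]:
    """One flat tuple with one embedding per item and no None entries.

    Returns None only when the batch has no multimodal input at all.
    """
    modalities = present_modalities(kwargs)
    if not modalities:
        return None
    embeddings: tuple[Any, ...] = ()
    for modality in modalities:
        embeddings += tuple(encoders[modality](kwargs))
    return embeddings
```

The matching `get_input_embeddings` would then merge this flat sequence against the image and audio placeholder tokens in the same order.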

I referred to this code. Should I parse the multimodal data in the order of the input placeholder tokens? If the input tokens look like [text, image, audio, text, image, audio], should I also rearrange the multimodal_embeddings tensors into this sequence and then perform the merge op?

https://github.com/vllm-project/vllm/blob/ed6e9075d31e32c8548b480a47d1ffb77da1f54c/vllm/model_executor/models/qwen2_5_vl.py#L949-L971

yanyongyu • Mar 06 '25 14:03

We don't support interleaved modality inputs yet. You just have to order the embeddings in the order of the first appearance of that modality.

DarkLight1337 • Mar 06 '25 14:03
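Following that answer, a pure-Python sketch of merging with modalities ordered by first appearance: each modality's item rows are concatenated in that order and written into the placeholder positions in sequence. This is only correct when modalities are not interleaved, so the sketch rejects interleaved prompts. Strings stand in for embedding rows; the function name is hypothetical and this is not vLLM's `merge_multimodal_embeddings`.

```python
from typing import Any


def merge_in_order(
    input_ids: list[int],
    text_embeds: list[Any],
    mm_embeds: dict[str, list[list[Any]]],  # modality -> per-item lists of rows
    token_ids: dict[str, int],              # modality -> placeholder token id
) -> list[Any]:
    """Fill placeholder positions with multimodal rows, modality by modality."""
    modality_of = {tid: m for m, tid in token_ids.items()}
    positions = [i for i, tok in enumerate(input_ids) if tok in modality_of]
    labels = [modality_of[input_ids[i]] for i in positions]
    # Collapse consecutive placeholders of the same modality into groups.
    groups = [m for j, m in enumerate(labels) if j == 0 or labels[j - 1] != m]
    if len(groups) != len(set(groups)):
        raise ValueError("interleaved modalities are not supported")
    # Without interleaving, `groups` lists each modality once, by first appearance.
    rows = [row for m in groups for item in mm_embeds.get(m, []) for row in item]
    if len(rows) != len(positions):
        raise ValueError(f"{len(rows)} embedding rows for {len(positions)} placeholders")
    merged = list(text_embeds)
    for pos, row in zip(positions, rows):
        merged[pos] = row
    return merged
```

For a prompt shaped like [text, image, audio, text, image, audio], the group check fails, which matches the answer above that interleaved modality inputs are not supported yet.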

I am a little confused. How does this get_multimodal_embeddings method affect the multimodal kwargs received during model forward? The issue is that when the input ids contain both decode and prefill stages, the content of the multimodal kwargs is incorrect.

I tried to debug this and found that when building the model input, there are two inter_data entries in inter_data_list, and both of them have multi_modal_kwargs. Both entries' multi_modal_kwargs are batched into the input mm kwargs: https://github.com/vllm-project/vllm/blob/ed6e9075d31e32c8548b480a47d1ffb77da1f54c/vllm/worker/model_runner.py#L974-L979

yanyongyu • Mar 07 '25 04:03

How this get_multimodal_embedding method affect the multimodal kwargs received during model forward?

That is the only thing that stands out to me, hence I mentioned it. Without actually debugging the code, it's difficult to see where the problem is coming from.

DarkLight1337 • Mar 07 '25 04:03

Perhaps @ywang96 @WoosukKwon have a better idea of this? I'm not really involved with the scheduler.

DarkLight1337 • Mar 07 '25 04:03

After digging into the model input builder, I found more info about this issue.

The builder gets the mm kwargs from the inter_data_list here: https://github.com/vllm-project/vllm/blob/ed6e9075d31e32c8548b480a47d1ffb77da1f54c/vllm/worker/model_runner.py#L974-L979

An inter_data entry is appended when a sequence group is added: https://github.com/vllm-project/vllm/blob/ed6e9075d31e32c8548b480a47d1ffb77da1f54c/vllm/worker/model_runner.py#L744-L754

And the inter_data's mm kwargs come from the sequence group: https://github.com/vllm-project/vllm/blob/ed6e9075d31e32c8548b480a47d1ffb77da1f54c/vllm/worker/model_runner.py#L674-L695

When there is a single decode request, the mm data in SequenceGroupMetadata is None. But when there is one decode request and one prefill request, the decode request's mm data in SequenceGroupMetadata is not None. This causes the model forward to receive the wrong number of mm items.

yanyongyu • Mar 07 '25 08:03
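A toy reproduction of the behavior described above (plain Python, not vLLM's model_runner code; `SeqGroupStub` and `batch_mm_kwargs` are hypothetical): if the decode-stage request still carries its multimodal data when it is batched with a new prefill, concatenating every request's items yields two images for one set of placeholders. Filtering to prefill requests expresses the invariant the model expects; in this thread the problem was ultimately resolved by switching to the V1 engine rather than by patching V0.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SeqGroupStub:
    """Hypothetical stand-in for one scheduled request in a V0 batch."""
    request_id: str
    is_prompt: bool  # True while the request is in its prefill step
    multi_modal_data: Optional[dict[str, list[Any]]] = None


def batch_mm_kwargs(groups: list[SeqGroupStub], prefill_only: bool) -> dict[str, list[Any]]:
    """Concatenate per-request multimodal items into one batch."""
    batched: dict[str, list[Any]] = {}
    for group in groups:
        if group.multi_modal_data is None:
            continue
        if prefill_only and not group.is_prompt:
            # A decode step reads the image/audio context from the KV cache and
            # its input_ids hold no placeholders, so its items must not be re-sent.
            continue
        for key, items in group.multi_modal_data.items():
            batched.setdefault(key, []).extend(items)
    return batched
```

With one decoding and one prefilling request, `prefill_only=False` reproduces the doubled `pixel_values` seen in the logs, while `prefill_only=True` leaves only the new request's image.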

Can you try using vLLM V1 and see if you get a similar problem? Since the scheduler is different in V1, simply switching to it might solve the issue.

DarkLight1337 • Mar 07 '25 08:03
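In the 0.7.x releases discussed here, the V1 engine was opt-in and enabled through the `VLLM_USE_V1` environment variable, which has to be set before vLLM is imported. A minimal sketch (worth confirming against the docs for your version, since later releases changed the default); `v1_requested` is a hypothetical helper:

```python
import os

# Opt in to the V1 engine (vLLM 0.7.x). Set this before `import vllm`,
# or export it in the shell: VLLM_USE_V1=1 vllm serve ...
os.environ["VLLM_USE_V1"] = "1"


def v1_requested() -> bool:
    """True if this process asked vLLM for the V1 engine."""
    return os.environ.get("VLLM_USE_V1") == "1"
```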

In get_multimodal_embeddings, you should make sure None values aren't in the returned tuple.

After fixing the above error and enabling the vLLM V1 engine, everything works. 🥹

Thanks for your help! @DarkLight1337

yanyongyu • Mar 07 '25 10:03

Hi, I am trying to do the same but haven't been able to figure out what arguments to pass to configure AsyncLLMEngine to use my custom model processor and model implementation (the docs seem to explain how to register the model, but not how to use it).

I checked vllm.config and it seems ModelConfig always loads an HF config, so I guess you changed the architectures and model_type fields in config.json to point to this one?

miguelalba96 • Mar 12 '25 13:03

You can pass --hf-overrides which lets you override the HF config via command line. There you can set the architectures field to use your own model.

DarkLight1337 • Mar 12 '25 13:03
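A sketch of that approach, assuming a vLLM version whose engine arguments accept `hf_overrides` (the Python-side counterpart of `--hf-overrides`) and that the custom class has been registered under the same name, e.g. via `ModelRegistry.register_model`. `CustomModelForConditionalGeneration` is the class name used earlier in this thread; the checkpoint path is a placeholder. The runnable part only builds the override; the engine calls are left as comments.

```python
import json

# Must equal the name the custom class is registered under with vLLM.
CUSTOM_ARCH = "CustomModelForConditionalGeneration"


def build_hf_overrides(arch: str = CUSTOM_ARCH) -> dict:
    """Override the `architectures` field of the checkpoint's HF config."""
    return {"architectures": [arch]}


def as_cli_flag(overrides: dict) -> str:
    """Render the overrides for the command line: vllm serve <model> --hf-overrides '<json>'."""
    return f"--hf-overrides '{json.dumps(overrides)}'"


if __name__ == "__main__":
    overrides = build_hf_overrides()
    print(as_cli_flag(overrides))
    # With vLLM installed (assumed API, check your version):
    #   from vllm import ModelRegistry
    #   from vllm.engine.arg_utils import AsyncEngineArgs
    #   from vllm.engine.async_llm_engine import AsyncLLMEngine
    #   ModelRegistry.register_model(CUSTOM_ARCH, CustomModelForConditionalGeneration)
    #   args = AsyncEngineArgs(model="/path/to/checkpoint", hf_overrides=overrides)
    #   engine = AsyncLLMEngine.from_engine_args(args)
```

Overriding the config at launch keeps the checkpoint's config.json untouched, so the same weights can still be loaded by the stock implementation.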

Can Qwen2-VL/Qwen2.5-VL use a CustomMultiModalProcessor to bypass the multimodal processor logic without modifying the source code?

pjgao • Mar 18 '25 08:03

You can register OOT models via plugins. However since Qwen2-VL is already defined inside vLLM, you will have to rename the model and use --hf-overrides to set the architecture name to use your custom implementation.

DarkLight1337 • Mar 18 '25 09:03