
[Usage] Batch inference with Llava 1.5

Open kimihailv opened this issue 8 months ago • 5 comments

Describe the issue

Currently, only inference with batch_size=1 is possible. If I understood correctly, the following changes are needed to enable batch inference:

  1. position_ids should be shifted to account for the left padding
  2. The attention mask should be passed and transformed for the multimodal forward pass

Maybe someone has managed to adapt the code?
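
For illustration, here is a toy sketch (not LLaVA code, all names made up) of how left padding, the attention mask, and the shifted position_ids relate:

import torch

pad_id = 0
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
max_len = max(len(s) for s in seqs)

# Left-pad every sequence to the same length.
input_ids = torch.stack(
    [torch.cat([torch.full((max_len - len(s),), pad_id), s]) for s in seqs]
)
# Mask is 0 on padding, 1 on real tokens.
attention_mask = torch.stack(
    [torch.cat([torch.zeros(max_len - len(s), dtype=torch.long), torch.ones(len(s), dtype=torch.long)]) for s in seqs]
)
# cumsum gives 1..n on real tokens; subtract 1 and clamp so every real token
# starts at position 0 regardless of how much padding precedes it.
position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)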

kimihailv avatar Oct 30 '23 16:10 kimihailv

Here's a processor that I wrote to make it work.

from typing import List, Union

import torch
from PIL import Image

from LLaVA.llava.constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from LLaVA.llava.conversation import conv_templates
from LLaVA.llava.mm_utils import tokenizer_image_token


class LlaVaProcessor:
    def __init__(self, tokenizer, image_processor, mm_use_im_start_end):
        self.mm_use_im_start_end = mm_use_im_start_end
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.conv_mode = "llava_v1"

    @staticmethod
    def load_demo_images(image_files: Union[List[str], str]):
        if type(image_files) is list:
            out = []
            for image_file in image_files:
                image = Image.open(image_file).convert("RGB")
                out.append(image)
        else:
            out = Image.open(image_files).convert("RGB")
        return out

    # TODO: refactor this (rough multi-image demo path)
    def get_processed_tokens_demo(self, text: str, image_files: Union[List[str], str]):
        # Append one image placeholder per image, mirroring format_text below;
        # tokenizer_image_token replaces each placeholder with IMAGE_TOKEN_INDEX.
        if self.mm_use_im_start_end:
            image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_token = DEFAULT_IMAGE_TOKEN

        num_images = len(image_files) if isinstance(image_files, list) else 1
        text = text + "\n" + "\n".join([image_token] * num_images)

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        images = self.load_demo_images(image_files)
        if not isinstance(images, list):
            images = [images]
        image_tensor = torch.stack(
            [self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
        )

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
        )

        return image_tensor, input_ids

    def format_text(self, text: str):
        if self.mm_use_im_start_end:
            text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + text
        else:
            text = DEFAULT_IMAGE_TOKEN + "\n" + text

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        text = conv.get_prompt()

        return text

    def load_image(self, image_path: str):
        return Image.open(image_path).convert("RGB")

    @staticmethod
    def pad_sequence_to_max_length(sequence, max_length, padding_value=0):
        """Pad a sequence to the desired max length."""
        if len(sequence) >= max_length:
            return sequence
        return torch.cat([torch.full((max_length - len(sequence),), padding_value, dtype=sequence.dtype), sequence])

    def get_processed_tokens(self, text: str, image_path: str):
        prompt = self.format_text(text)
        image = self.load_image(image_path)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

        return image_tensor, input_ids

    def get_processed_tokens_batch(self, batch_text: List[str], image_paths: List[str]):
        prompts = [self.format_text(text) for text in batch_text]
        images = [self.load_image(image_path) for image_path in image_paths]

        batch_input_ids = [
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in prompts
        ]

        # Determine the maximum length of input_ids in the batch
        max_len = max([len(seq) for seq in batch_input_ids])
        # Pad each sequence in input_ids to the max_len
        padded_input_ids = [self.pad_sequence_to_max_length(seq.squeeze(), max_len) for seq in batch_input_ids]
        batch_input_ids = torch.stack(padded_input_ids)

        batch_image_tensor = self.image_processor(images, return_tensors="pt")["pixel_values"]

        return batch_image_tensor, batch_input_ids
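
For context, a rough usage sketch follows; the load_pretrained_model call, model path, and file names are assumptions (adapt them to your checkpoint), but it shows how the processor feeds the generation code below.

from LLaVA.llava.model.builder import load_pretrained_model

# Assumed setup: load LLaVA 1.5 weights with the repo's builder.
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b", model_base=None, model_name="llava-v1.5-7b"
)
processor = LlaVaProcessor(tokenizer, image_processor, model.config.mm_use_im_start_end)

batch = {}
batch["image_tensors"], batch["input_ids"] = processor.get_processed_tokens_batch(
    ["Describe the image.", "What color is the car?"],
    ["cat.jpg", "car.jpg"],
)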

You can now run inference:

from LLaVA.llava.conversation import SeparatorStyle, conv_templates
from LLaVA.llava.mm_utils import KeywordsStoppingCriteria

input_ids = batch["input_ids"].cuda()
image_tensor = batch["image_tensors"]

conv = conv_templates[processor.conv_mode].copy()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = (
    [KeywordsStoppingCriteria(keywords, processor.tokenizer, input_ids)]
    if conv.version == "v0"
    else None
)

output_ids = model.generate(
    input_ids,
    images=image_tensor.half().cuda(),
    num_beams=self.args.num_beams,
    max_new_tokens=self.args.max_length,
    length_penalty=self.args.length_penalty,
    use_cache=True,
    stopping_criteria=stopping_criteria,
    do_sample=self.args.do_sample,
    temperature=self.args.temperature,
    num_return_sequences=self.args.num_return_sequences,
)
generated_outputs = processor.tokenizer.batch_decode(
    output_ids[:, input_ids.shape[1] :], skip_special_tokens=True
)
generated_outputs = [out.strip() for out in generated_outputs]
generated_outputs = [
    out[: -len(stop_str)] if out.endswith(stop_str) else out for out in generated_outputs
]
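
Note that the batch is left-padded but no attention mask is passed to generate above. You may want to build one explicitly; a sketch, assuming the default padding value 0 used by pad_sequence_to_max_length (whether the custom generate consumes it depends on the repo version):

pad_value = 0  # default padding value in pad_sequence_to_max_length above
attention_mask = (input_ids != pad_value).long()

output_ids = model.generate(
    input_ids,
    attention_mask=attention_mask.cuda(),
    images=image_tensor.half().cuda(),
    max_new_tokens=128,
    use_cache=True,
)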

You can also check my vqa-prompting codebase for full support!

rabiulcste avatar Nov 02 '23 16:11 rabiulcste

Hi,

Batched inference with LLaVa is supported in Hugging Face Transformers. See here for an example: https://github.com/huggingface/transformers/blob/a49f4acab3c1eea82907e12f82eafbd4673deb39/tests/models/llava/test_modeling_llava.py#L245.

NielsRogge avatar Dec 12 '23 14:12 NielsRogge

In case anyone else finds this, here is a sample of working batch inference code based on the link above

import base64
from io import BytesIO

from PIL import Image

# self.processor / self.model are assumed to be a Hugging Face AutoProcessor
# and LlavaForConditionalGeneration loaded elsewhere in the class.
prompt_temp = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"

prompts = []
images = []

for user_question, base64_image in zip(user_questions, images_data):
    prompt = prompt_temp.format(user_question)
    prompts.append(prompt)

    image_data = base64.b64decode(base64_image)
    image = Image.open(BytesIO(image_data))
    images.append(image)

# Perform batch inference
inputs = self.processor(prompts, images=images, return_tensors="pt", padding=True).to("cuda:0")
output = self.model.generate(**inputs, max_new_tokens=4000)

answer = self.processor.batch_decode(output, skip_special_tokens=True)
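
If you only want the newly generated text rather than the echoed prompt, you can slice the output before decoding, the same way the earlier snippet does; a sketch, assuming generate returns the prompt ids followed by the new tokens:

# Decode only the tokens generated after the prompt.
generated_ids = output[:, inputs["input_ids"].shape[1]:]
answer = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
answer = [a.strip() for a in answer]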

david-vectorflow avatar Apr 22 '24 22:04 david-vectorflow

> Hi,
>
> Batched inference with LLaVa is supported in Hugging Face Transformers. See here for an example: https://github.com/huggingface/transformers/blob/a49f4acab3c1eea82907e12f82eafbd4673deb39/tests/models/llava/test_modeling_llava.py#L245.

Hey, @NielsRogge I've stumbled upon this issue today. It seems that the same code does not work for LlavaNextForConditionalGeneration. Is batched inference for LlavaNext models supported in some other ways?

For reference, it crashed when trying to stack new_image_features; the per-image feature tensors end up with different lengths because LLaVA-NeXT splits each image into a variable number of patches depending on its resolution and aspect ratio:

File ~/miniconda3/envs/vlm_safety_eval/lib/python3.10/site-packages/transformers/models/llava_next/modeling_llava_next.py:553, in LlavaNextForConditionalGeneration.forward(self, input_ids, pixel_values, image_sizes, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    551         image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
    552     new_image_features.append(image_feature)
--> 553 image_features = torch.stack(new_image_features, dim=0)
    555 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
    556     image_features, inputs_embeds, input_ids, attention_mask, labels
    557 )
    558 if labels is None:

RuntimeError: stack expects each tensor to be equal size, but got [2144, 4096] at entry 0 and [2340, 4096] at entry 1

g8a9 avatar Apr 30 '24 09:04 g8a9

Yes I'm aware of that, this is being addressed in https://github.com/huggingface/transformers/pull/29850

It will be part of the next Transformers release!

NielsRogge avatar Apr 30 '24 09:04 NielsRogge