[Usage] Batch inference with Llava 1.5
Describe the issue
Currently, only inference with batch_size=1 is possible. If I understood correctly, the following needs to change to support batch inference:
- position_ids should be shifted to account for the left padding
- the attention mask should be passed in and transformed for the multimodal forward
Maybe someone has managed to adapt the code?
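For reference, this is roughly how position ids can be recomputed from a left-padded attention mask (a minimal sketch of the general idea, not code from the LLaVA repo):

```python
import torch

# Two left-padded sequences: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum over real tokens yields 0, 1, 2, ... per sequence;
# padded positions get a dummy value since they are masked out anyway.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# -> tensor([[1, 1, 0, 1, 2],
#            [0, 1, 2, 3, 4]])
```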
Here's a processor that I wrote to make it work:

```python
from typing import List, Union

import torch
from PIL import Image

from LLaVA.llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from LLaVA.llava.conversation import conv_templates
from LLaVA.llava.mm_utils import tokenizer_image_token


class LlaVaProcessor:
    def __init__(self, tokenizer, image_processor, mm_use_im_start_end):
        self.mm_use_im_start_end = mm_use_im_start_end
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.conv_mode = "llava_v1"

    @staticmethod
    def load_demo_images(image_files: Union[List[str], str]):
        # Always return a list of RGB images, even for a single path.
        if isinstance(image_files, list):
            out = [Image.open(image_file).convert("RGB") for image_file in image_files]
        else:
            out = [Image.open(image_files).convert("RGB")]
        return out

    # TODO: refactor this, not working (`image_token_len` must be supplied by the caller)
    def get_processed_tokens_demo(self, text: str, image_files: Union[List[str], str], image_token_len: int):
        qs = text
        if self.mm_use_im_start_end:
            qs = (
                qs
                + "\n"
                + DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                + DEFAULT_IM_END_TOKEN
                + "\n"
                + DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                + DEFAULT_IM_END_TOKEN
            )
        else:
            qs = (
                qs
                + "\n"
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                + "\n"
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            )

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        images = self.load_demo_images(image_files)
        image_tensor = torch.stack(
            [self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
        )

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
        )

        return image_tensor, input_ids

    def format_text(self, text: str):
        if self.mm_use_im_start_end:
            text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + text
        else:
            text = DEFAULT_IMAGE_TOKEN + "\n" + text

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        text = conv.get_prompt()

        return text

    def load_image(self, image_path: str):
        return Image.open(image_path).convert("RGB")

    @staticmethod
    def pad_sequence_to_max_length(sequence, max_length, padding_value=0):
        """Left-pad a sequence to the desired max length."""
        if len(sequence) >= max_length:
            return sequence
        return torch.cat([torch.full((max_length - len(sequence),), padding_value, dtype=sequence.dtype), sequence])

    def get_processed_tokens(self, text: str, image_path: str):
        prompt = self.format_text(text)
        image = self.load_image(image_path)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

        return image_tensor, input_ids

    def get_processed_tokens_batch(self, batch_text: List[str], image_paths: List[str]):
        prompts = [self.format_text(text) for text in batch_text]
        images = [self.load_image(image_path) for image_path in image_paths]

        batch_input_ids = [
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in prompts
        ]

        # Determine the maximum length of input_ids in the batch
        max_len = max(len(seq) for seq in batch_input_ids)

        # Left-pad each sequence in input_ids to max_len
        padded_input_ids = [self.pad_sequence_to_max_length(seq.squeeze(), max_len) for seq in batch_input_ids]
        batch_input_ids = torch.stack(padded_input_ids)

        batch_image_tensor = self.image_processor(images, return_tensors="pt")["pixel_values"]

        return batch_image_tensor, batch_input_ids
```
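To tie it together, a minimal sketch of how the processor might be wired up (the `tokenizer`, `image_processor`, `questions`, and `image_paths` names here are placeholders, and the `batch` dict is just how I pass things to the inference snippet below):

```python
# Hypothetical setup: assumes a LLaVA tokenizer and image processor are already loaded.
processor = LlaVaProcessor(tokenizer, image_processor, mm_use_im_start_end=False)

questions = ["What is in this image?", "Describe the scene."]
image_paths = ["cat.jpg", "street.jpg"]

image_tensors, input_ids = processor.get_processed_tokens_batch(questions, image_paths)
batch = {"input_ids": input_ids, "image_tensors": image_tensors}
```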
You can now run inference:

```python
from LLaVA.llava.conversation import SeparatorStyle, conv_templates
from LLaVA.llava.mm_utils import KeywordsStoppingCriteria

# Assumes `model`, `processor`, `batch`, and generation `args` are already defined.
conv = conv_templates[processor.conv_mode].copy()
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]

input_ids = batch["input_ids"].cuda()
image_tensor = batch["image_tensors"]

stopping_criteria = (
    [KeywordsStoppingCriteria(keywords, processor.tokenizer, input_ids)]
    if conv.version == "v0"
    else None
)

output_ids = model.generate(
    input_ids,
    images=image_tensor.half().cuda(),
    num_beams=args.num_beams,
    max_new_tokens=args.max_length,
    length_penalty=args.length_penalty,
    use_cache=True,
    stopping_criteria=stopping_criteria,
    do_sample=args.do_sample,
    temperature=args.temperature,
    num_return_sequences=args.num_return_sequences,
)

# Decode only the newly generated tokens and strip the stop string.
generated_outputs = processor.tokenizer.batch_decode(
    output_ids[:, input_ids.shape[1]:], skip_special_tokens=True
)
generated_outputs = [out.strip() for out in generated_outputs]
generated_outputs = [
    out[: -len(stop_str)] if out.endswith(stop_str) else out for out in generated_outputs
]
```
You can also check my vqa-prompting codebase for full support!
Hi,
Batched inference with LLaVa is supported in Hugging Face Transformers. See here for an example: https://github.com/huggingface/transformers/blob/a49f4acab3c1eea82907e12f82eafbd4673deb39/tests/models/llava/test_modeling_llava.py#L245.
In case anyone else finds this, here is a sample of working batch inference code based on the link above:
```python
import base64
from io import BytesIO

from PIL import Image

prompt_temp = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"

prompts = []
images = []
for user_question, base64_image in zip(user_questions, images_data):
    prompt = prompt_temp.format(user_question)
    prompts.append(prompt)

    image_data = base64.b64decode(base64_image)
    image = Image.open(BytesIO(image_data))
    images.append(image)

# Perform batch inference
inputs = self.processor(prompts, images=images, return_tensors="pt", padding=True).to("cuda:0")
output = self.model.generate(**inputs, max_new_tokens=4000)
answer = self.processor.batch_decode(output, skip_special_tokens=True)
```
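One caveat worth noting with the snippet above: decoder-only models need left padding for batched generation, and slicing off the prompt tokens makes decoding cleaner. A sketch, reusing the `self.processor` / `self.model` names from the snippet above:

```python
# Left padding is required so generated tokens line up at the end of each row.
self.processor.tokenizer.padding_side = "left"

inputs = self.processor(prompts, images=images, return_tensors="pt", padding=True).to("cuda:0")
output = self.model.generate(**inputs, max_new_tokens=4000)

# Decode only the newly generated tokens, skipping the prompt part of each row.
new_tokens = output[:, inputs["input_ids"].shape[1]:]
answers = self.processor.batch_decode(new_tokens, skip_special_tokens=True)
```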
> Hi,
> Batched inference with LLaVa is supported in Hugging Face Transformers. See here for an example: https://github.com/huggingface/transformers/blob/a49f4acab3c1eea82907e12f82eafbd4673deb39/tests/models/llava/test_modeling_llava.py#L245.
Hey @NielsRogge, I've stumbled upon this issue today. It seems that the same code does not work for `LlavaNextForConditionalGeneration`. Is batched inference for LlavaNext models supported in some other way?

For reference, it crashes when trying to stack `new_image_features`:
```
File ~/miniconda3/envs/vlm_safety_eval/lib/python3.10/site-packages/transformers/models/llava_next/modeling_llava_next.py:553, in LlavaNextForConditionalGeneration.forward(self, input_ids, pixel_values, image_sizes, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    551 image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
    552 new_image_features.append(image_feature)
--> 553 image_features = torch.stack(new_image_features, dim=0)
    555 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
    556     image_features, inputs_embeds, input_ids, attention_mask, labels
    557 )
    558 if labels is None:

RuntimeError: stack expects each tensor to be equal size, but got [2144, 4096] at entry 0 and [2340, 4096] at entry 1
```
Yes, I'm aware of that; this is being addressed in https://github.com/huggingface/transformers/pull/29850. It will be part of the next Transformers release!