Video-LLaVA
How to batch evaluate in inference?
Hi, how can I make the inference code evaluate videos in batches? I naively concatenated the tensors along dimension 0 and got this error.
Can you help me figure it out? Thanks.
I did something similar, but instead of concatenating, I used a for loop. Got the exact same error. Were you able to resolve it?
I don't know the details of this problem, but I found that the input_ids tensor should not exceed batch size 1, e.g. a [1, 50] tensor works but a [4, 50] tensor doesn't. Hope that helps.
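For what it's worth, a tiny self-contained sketch of that shape observation (dummy tensors only, not from the repo): a single tokenized prompt has shape [1, seq_len], and naively concatenating several prompts along dim 0 produces the [batch, seq_len] input_ids that fail.

import torch

# Dummy token ids, just to illustrate the shapes involved.
single = torch.randint(0, 32000, (1, 50))                       # [1, 50] -- works
batched = torch.cat([single.clone() for _ in range(4)], dim=0)  # [4, 50] -- fails in generate()
print(single.shape, batched.shape)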
I hit the same problem when I put the inference in a for loop, which is strange.
Any solution?
I'm not sure about the root cause of this problem, but I implemented batch evaluation with the following workaround. I hope it helps.
I added a generate method to the class LlavaLlamaForCausalLM in videollava/model/language_model, just like the original LLaVA code:
# (Requires `from typing import Optional, Union` and
#  `from transformers.generation.utils import GenerateOutput` at the top of the file.)
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    images: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    position_ids = kwargs.pop("position_ids", None)
    attention_mask = kwargs.pop("attention_mask", None)
    if "inputs_embeds" in kwargs:
        raise NotImplementedError("`inputs_embeds` is not supported")
    if images is not None:
        # Splice the image/video features into the prompt and return the
        # resulting embeddings instead of raw input_ids.
        (
            inputs,
            position_ids,
            attention_mask,
            _,
            inputs_embeds,
            _
        ) = self.prepare_inputs_labels_for_multimodal(
            inputs,
            position_ids,
            attention_mask,
            None,
            None,
            images,
        )
    else:
        inputs_embeds = self.get_model().embed_tokens(inputs)
    return super().generate(
        position_ids=position_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        **kwargs
    )
Then, at the inference call site, I collect the text-part input_ids into a list, pad them to the same length with torch.nn.utils.rnn.pad_sequence, and then move the padding to the front (left padding) using a roll_padding_to_front helper:
import torch
from torch.nn.utils.rnn import pad_sequence
from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from videollava.conversation import conv_templates, SeparatorStyle
from videollava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria


def roll_padding_to_front(padded_input_ids, padding_value=0):
    padding_lengths = (padded_input_ids == padding_value).long().sum(dim=1)
    rolled_input_ids = torch.stack([
        torch.roll(input_id, shifts=padding_length.item())
        for input_id, padding_length in zip(padded_input_ids, padding_lengths)
    ])
    return rolled_input_ids
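# Illustration (not from the original post) of what roll_padding_to_front
# does with padding_value=0: the right padding is rotated to the front, e.g.
#   [[5, 6, 0, 0],            [[0, 0, 5, 6],
#    [7, 8, 9, 0]]   becomes   [0, 7, 8, 9]]
# so every sequence ends at the last position and the batch is effectively
# left-padded before generation.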
def the_generate(
    tokenizer, model, model_name, video_processor,
    questions, video_path, conv_mode, device, temperature, max_new_tokens
):
    video_outputs = video_processor(video_path, return_tensors='pt')
    video_tensor = video_outputs['pixel_values']
    video_prompts = video_outputs['prompts']
    # 'video_prompt' contains the video frames and time points, same as Video-LLaMA
    if type(video_tensor) is list:
        videos_tensor = [video.to(device, dtype=torch.float16) for video in video_tensor]
    else:
        videos_tensor = video_tensor.to(device, dtype=torch.float16)

    inputs_ids = []
    stopping_criterias = []
    for question, video_prompt in zip(questions, video_prompts):
        conv_m = "llava_v1"
        if conv_mode is not None and conv_m != conv_mode:
            print(
                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_m, conv_mode, conv_mode
                )
            )
        else:
            conv_mode = conv_m
        conv = conv_templates[conv_mode].copy()
        roles = conv.roles
        question = ' '.join([DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames) + '\n' \
            + video_prompt + question
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
            .unsqueeze(0)
        )
        inputs_ids.append(input_ids.squeeze(0))
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        stopping_criterias.append(stopping_criteria)

    padded_input_ids = pad_sequence(
        inputs_ids,
        batch_first=True,
        padding_value=tokenizer.pad_token_id
    ).to(device)
    # move padding tokens ahead (left padding), detecting them by the same pad_token_id
    rolled_input_ids = roll_padding_to_front(padded_input_ids, padding_value=tokenizer.pad_token_id)

    with torch.inference_mode():
        output_ids = model.generate(
            rolled_input_ids,
            images=videos_tensor,
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=stopping_criterias
        )

    outputs = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=True
    )
    outputs = [x.strip() for x in outputs]
    return outputs
I'm having the same problem :-( Any advice?
Not sure why @xiningin's code also returns an error for me :-(
If you are running inference in a loop, make sure that you reset the conversation template conv. Otherwise, the prompt you pass to the model each time includes all of the previous messages. So when you process the second video, there will be 16 image tokens instead of 8 due to this line: 8 from the prompt for the first video and 8 from the prompt for the second video. When the code gets to this point, it will try to process 16 image features, but image_features will only have 8 elements, hence the list index out of range error.
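For reference, here is a minimal sketch of resetting the template on every iteration (the helper name build_prompt_ids and the loop variables are illustrative, not from the repository; it just reuses the prompt-building steps already shown in this thread):

from videollava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from videollava.conversation import conv_templates
from videollava.mm_utils import tokenizer_image_token

def build_prompt_ids(question, tokenizer, model, conv_mode="llava_v1"):
    # Re-create the conversation template for every sample, so earlier
    # questions (and their 8 image tokens) do not accumulate in the prompt.
    conv = conv_templates[conv_mode].copy()
    question = ' '.join([DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames) \
        + '\n' + question
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    return tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

# Usage in a per-video loop: one fresh prompt per video, never a shared conv.
# for question, video in zip(questions, videos):
#     input_ids = build_prompt_ids(question, tokenizer, model).unsqueeze(0).to(device)
#     ...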
@RaulKite Maybe we aren't using the same versions of some packages. Here are my main package versions; you can try installing the same versions as mine, or post the error you ran into.
torch 2.2.1
torchaudio 2.2.1
torchinfo 1.8.0
torchstat 0.0.7
torchvision 0.15.2
tqdm 4.66.2
transformers 4.31.0
videollava 1.0.0 .../Video-LLaVA
pytorchvideo 0.1.5
pycocoevalcap 1.2
pycocotools 2.0.7
omegaconf 2.3.0
And I call the_generate() in a loop like this:
for batch in tqdm(dataloader):
    ......
    outputs = the_generate(........)
    .......
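For completeness, a minimal sketch of what one such loop might look like (the batch keys 'questions' and 'video_paths', and the collecting of results, are assumptions for illustration, not from the original script):

from tqdm import tqdm

all_outputs = []
for batch in tqdm(dataloader):
    # Assumed batch layout: a list of questions and the matching video paths.
    outputs = the_generate(
        tokenizer, model, model_name, video_processor,
        batch['questions'], batch['video_paths'],
        conv_mode=None, device='cuda',
        temperature=0.2, max_new_tokens=512,
    )
    all_outputs.extend(outputs)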