LLaVA-NeXT
Inference issue on Pretrained MLP and Original Qwen2-7B-Instruct model
Hi, I have used your codebase to pretrain an MLP on the "Qwen2-7B-Instruct" base model, incorporating the "openai/clip-vit-large-patch14" encoder. The training process was smooth, with a noticeable reduction in loss. However, I'm running into an unexpected error at inference time after loading the pretrained model.
I've made some changes to the logic in the "prepare_inputs_labels_for_multimodal" function and have carefully checked the shapes of the returned variables. Based on this, everything matches and is returned as expected during training. When debugging at inference time, the attention mask shape initially matches cur_input_ids, but then the unintended behavior occurs, at which point cur_input_ids ends up with shape [0].
Traceback (most recent call last):
  File "/data/users/LLaVA-NeXT/llava/eval/model_pku.py", line 226, in <module>
    results, total_time = process_json_file(json_file_path, video_dir)
  File "/data/users/LLaVA-NeXT/llava/eval/model_pku.py", line 166, in process_json_file
    model_answer = process_video_question(video_path, question)
  File "/data/users/LLaVA-NeXT/llava/eval/model_pku.py", line 132, in process_video_question
    cont = model.generate(
  File "/home/.miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/data/users/LLaVA-NeXT/llava/model/language_model/llava_qwen.py", line 338, in generate
    return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
  File "/home/.miniconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/.miniconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 1758, in generate
    result = self._sample(
  File "/home/.miniconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 2397, in _sample
    outputs = self(
  File "/home/.miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/.miniconda3/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/users/LLaVA-NeXT/llava/model/language_model/llava_qwen.py", line 280, in forward
    (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images1, images2, modalities, image_sizes, inference=inference)
  File "/data/users/LLaVA-NeXT/llava/model/llava_arch.py", line 429, in prepare_inputs_labels_for_multimodal
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
  File "/data/users/LLaVA-NeXT/llava/model/llava_arch.py", line 429, in <listcomp>
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
IndexError: The shape of the mask [1096] at index 0 does not match the shape of the indexed tensor [1] at index 0
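If it helps, here is a minimal stand-alone snippet that reproduces the same IndexError with the shapes reported in the traceback. The two tensors below are stand-ins for the variables at llava_arch.py line 429 during decoding, not the actual values from my run:

import torch

# Stand-ins with the shapes from the traceback: a single-token input_ids tensor
# indexed by a boolean attention mask over the full (expanded) sequence.
cur_input_ids = torch.zeros(1, dtype=torch.long)          # only the last generated token id
cur_attention_mask = torch.ones(1096, dtype=torch.bool)   # mask over the full sequence

# Boolean-mask indexing requires the mask and the indexed tensor to have the same
# length along dim 0, so this raises:
# IndexError: The shape of the mask [1096] at index 0 does not match the shape of
# the indexed tensor [1] at index 0
cur_input_ids[cur_attention_mask]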
Here is my forward call:
video_frames = load_video(video_path, 32)
image_tensors = []
frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].cuda()
image_tensors.append(frames)
# Prepare conversation input
conv_template = "qwen_1_5"
prompt = f"{DEFAULT_VIDEO_TOKEN}\n{question}"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_question, tokenizer, VIDEO_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
# input_ids = preprocess_qwen({'from': 'gpt','value': prompt}, tokenizer, has_video=True).cuda()
image_sizes = [frame.shape[:2] for frame in video_frames]
img_tensor = image_tensors[0].unsqueeze(0).to(dtype=torch.float16).cuda()
stopping_criteria = KeywordsStoppingCriteria("<|endoftext|>", tokenizer, input_ids)
# Generate response
# with torch.inference_mode():
cont = model.generate(
    input_ids,
    images1=img_tensor,
    images2=img_tensor.clone(),
    image_sizes=image_sizes,
    # do_sample=True,
    do_sample=False,
    temperature=0.0,
    max_new_tokens=4096,
    modalities=["video"],
    stopping_criteria=[stopping_criteria],
)
Here is more context. Notice that both branches of the if/else below are hit during a single inference call.
class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaQwenConfig

    def __init__(self, config):
        # super(Qwen2ForCausalLM, self).__init__(config)
        Qwen2ForCausalLM.__init__(self, config)
        config.model_type = "llava_qwen"
        config.rope_scaling = None

        self.model = LlavaQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images1: Optional[torch.FloatTensor] = None,
        images2: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        dpo_forward: Optional[bool] = False,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # breakpoint()
        if inputs_embeds is None:
            print("This case is met")
            print("yet the else case was also met?")
            (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images1, images2, modalities, image_sizes, inference=inference)
        else:
            print("This case is not met")
The printed output from a single call to generate:

This case is not met
This case is met
yet the else case was also met?
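As far as I understand, transformers' generate loops over decoding steps and calls forward once per step, so the branch taken can differ across steps: the first step receives the inputs_embeds produced by the generate override (llava_qwen.py line 338) and takes the else branch, while later steps pass only the newly generated token ids with inputs_embeds left as None, so they take the if branch with a 1-token input_ids but a full-length attention_mask. A rough toy sketch of that loop (hypothetical, not the real transformers code):

import torch

# Toy greedy decoding loop illustrating why both branches print in one call.
# Step 0 passes inputs_embeds ("This case is not met"); later steps pass only
# the last generated token id while attention_mask keeps growing ("This case is met").
def toy_decode_loop(model, inputs_embeds, attention_mask, max_new_tokens):
    generated = []
    input_ids = None
    for step in range(max_new_tokens):
        if input_ids is None:
            out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            out = model(input_ids=input_ids, attention_mask=attention_mask)
        next_token = out.logits[:, -1:].argmax(dim=-1)   # greedy pick of the next token id
        generated.append(next_token)
        input_ids = next_token                           # only the last token is fed back in
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones(attention_mask.size(0), 1)], dim=1
        )
    return torch.cat(generated, dim=1)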
Any suggestions?