LLaVA
Encountered problems when handling queries with multiple images.
Question
Encountered an issue while processing queries involving multiple images in:
llava/model/llava_arch.py at line 119: image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
I'm questioning the use of flatten on tensor x at this point. Based on the subsequent code, it seems more appropriate to split it into a list rather than flatten it.
In my understanding, when type(images) is list (line 114), each element of the 'images' list stores the images contained in the i-th sentence of the current batch. At that point, the 'image_features' list stores the tensors resulting from 'self.encode_images' for each sentence in the batch. Therefore, the shape of each x should (I think) be a tensor of shape [num, len, dim], where num is the number of images in the current sentence. Flattening it merges num and len, causing issues when executing the code at line 178:
cur_image_features = image_features[cur_image_idx]
The problem is that 'cur_image_idx' exceeds the range of image_features. This is because 'cur_image_idx' can reach a maximum value equal to the total number of images across the whole batch, while the flattened image_features list only has one entry per sentence.
I think changing the flatten operation at line 119 to build a list with one entry per image, rather than concatenating all images of the same sentence in advance, is necessary to align with the use of 'cur_image_idx'.
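The mismatch above can be reproduced with a minimal sketch using made-up shapes (2 images in sentence 0, 3 in sentence 1; the patch and hidden dims are arbitrary):

```python
import torch

# Hypothetical per-sentence encoder outputs: [num_images, num_patches, hidden_dim]
image_features = [torch.randn(2, 4, 8), torch.randn(3, 4, 8)]

# Current code: flatten(0, 1) merges the image and patch dims per sentence,
# so the list length stays at the batch size (2 here)...
flattened = [x.flatten(0, 1) for x in image_features]
assert len(flattened) == 2
assert flattened[0].shape == (2 * 4, 8)  # num * len merged

# ...but cur_image_idx counts individual images across the batch (up to 5 here),
# so image_features[cur_image_idx] eventually indexes past the end of the list:
total_images = sum(x.shape[0] for x in image_features)
assert total_images > len(flattened)  # -> IndexError at line 178
```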
+1 The block starting at line 114 just concatenates all the images in the batch in order and encodes them. Here is my understanding of the code and a suggested modification.
if type(images) is list or images.ndim == 5:
    concat_images = torch.cat([image for image in images], dim=0)  # -> (#images in a batch, C, H, W)
    image_features = self.encode_images(concat_images)  # -> (#images in a batch, #patches per image, hidden_dim)
    # split_sizes = [image.shape[0] for image in images]
    # image_features = torch.split(image_features, split_sizes, dim=0)  # divides into each sentence
    # image_features = [x.flatten(0, 1).to(self.device) for x in image_features]  # -> (B, #images in a sentence * #patches, hidden_dim)
    image_features = image_features.to(self.device)  # this might be needed
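To illustrate why dropping the split/flatten step fixes the indexing, here is a sketch with made-up shapes, where a random tensor stands in for the output of self.encode_images:

```python
import torch

# Hypothetical batch: 2 images in sentence 0, 3 images in sentence 1
images = [torch.randn(2, 3, 16, 16), torch.randn(3, 3, 16, 16)]

concat_images = torch.cat([image for image in images], dim=0)  # (5, C, H, W)

# Stand-in for self.encode_images: (#images in batch, #patches, hidden_dim)
image_features = torch.randn(concat_images.shape[0], 4, 8)

# Without the flatten, dim 0 stays one-per-image, so indexing with a
# running cur_image_idx (0..4) stays in range:
assert image_features.shape[0] == 5
assert image_features[4].shape == (4, 8)  # features of one image
```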
@Yuki-Imajuku @smzzl You got any solution for this then?
@haotian-liu could you please have a look at this issue?
> @Yuki-Imajuku @smzzl You got any solution for this then?

No, I use start and end tokens to rewrite it.
oh, I had to remove