
BLIP2: processing a batch of text for image-text matching

Open · AnaRhisT94 opened this issue 1 year ago • 2 comments

Hi,

Following the example here: https://github.com/salesforce/LAVIS/blob/main/examples/blip2_image_text_matching.ipynb

How can I use self.text_processors["eval"](text) to process a batch of text without writing a for loop? The previous BLIP supported this, and I was curious whether I can achieve the same effect with BLIP2.
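
For context, the per-item workaround I'm trying to avoid looks like the sketch below: the "eval" text processor only accepts a single string, so batching it means a comprehension, which is still a loop in disguise. The variable names here are just illustrative:

texts = ["A picture of woman", "A picture of man", "A picture of dog"] * 50
# text_processors comes from load_model_and_preprocess; its "eval" entry
# cleans one caption string at a time, so batching requires iterating.
processed = [text_processors["eval"](t) for t in texts]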

Error:

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (input_ids in this case) have excessive nesting (inputs type list where type int is expected).

Basically, the error means that because the texts have different lengths, I need to enable padding and truncation somewhere, but I still haven't found the appropriate place to set them. I don't want to enable truncation, because I don't want to lose characters from the text I feed in. Increasing the number of tokens the tokenizer accepts therefore seems like a good solution, but I'm not sure how to achieve it. Any ideas?
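
For what it's worth, here is a rough sketch of what I mean by enabling padding without truncation. I'm assuming the text eventually goes through a Hugging Face BERT-style tokenizer, as in the original BLIP; the tokenizer checkpoint name below is my assumption, not taken from the BLIP2 source:

from transformers import BertTokenizer

# Assumption: BLIP2's text branch uses a BERT-style tokenizer, as BLIP did.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

texts = ["A picture of woman", "A picture of man", "A picture of dog"] * 50

# padding=True pads every caption to the longest one in the batch, so the
# resulting input_ids tensor is rectangular; truncation stays off so no
# characters are lost.
tokens = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
print(tokens["input_ids"].shape)  # (150, length_of_longest_caption_in_tokens)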

Here's my code:

from PIL import Image
import requests
import torch
from lavis.models import load_model_and_preprocess

class BLIP_Captioner:

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model, self.vis_processors, self.text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=self.device, is_eval=True)
    
    def generate_matching(self, image, captions):

        with torch.no_grad():
            # Preprocess the image and add a batch dimension
            img = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)
            # Repeat the image so its batch size matches the number of captions
            itc_scores = self.model({"image": img.repeat(len(captions), 1, 1, 1), "text_input": captions}, match_head='itc')
            return itc_scores

def main():

    blip_instance = BLIP_Captioner()

    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' 
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

    captions = ["A picture of woman", "A picture of man", "A picture of dog"] * 50
    scores = blip_instance.generate_matching(raw_image, captions)
    print(scores)

if __name__ == '__main__':
    main()

AnaRhisT94 · Feb 19 '23, 13:02

Hi, I wonder if you have solved the problem, since I'm also encountering this issue.

JeffreyYzh · Nov 19 '23, 04:11

Hi, has either of you found a solution?

Coronal-Halo · Feb 06 '24, 19:02