Efficient-Computing
Script for generating text for GPT4Image
I'm trying to transfer GPT4Image from ImageNet to a custom dataset. How can I generate the corresponding text for a batch of images with MiniGPT4? Could you please provide a reference script?
I used a very old version of MiniGPT4; the last commit was likely 22d8888 from May 1, 2023. You can access that legacy code through this link: https://github.com/Vision-CAIR/MiniGPT-4/tree/22d8888ca2cf0aac862f537e7d22ef5830036808
Place the following script, generate_captions.py, in the root folder of MiniGPT4 and modify it as needed.
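Side note for the custom-dataset case: the script below loads images with torchvision.datasets.ImageFolder, which expects data_path to contain one subdirectory per class. If your custom dataset is unlabeled, putting all images under a single dummy class folder is enough, since the labels are discarded anyway. A minimal sketch, assuming a hypothetical custom_dataset/ directory:

# Hypothetical layout for an unlabeled custom dataset:
#   custom_dataset/
#       all/
#           img_0001.jpg
#           img_0002.jpg
#           ...
import torchvision

dataset = torchvision.datasets.ImageFolder('custom_dataset')
print(len(dataset))  # number of images found under the dummy class folder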
"""generate_captions.py"""
import argparse
import random
from time import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import os
import torchvision
from copy import deepcopy
import datetime
from pathlib import Path
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION
# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
chat_state = CONV_VISION.copy()
input_prompt = 'Describe this image. Do not say anything that you are not sure.'
def prepare_model(args):
    """Build the MiniGPT-4 model, its visual processor, and the Chat wrapper from the eval config."""
    print('Initializing Chat')
    cfg = Config(args)
    model_config = cfg.model_cfg
    model_config.device_8bit = args.gpu_id
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
    vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
    vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
    chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
    print('Initialization Finished')
    return chat
def get_single_caption(chat, image, input_prompt):
    """Caption a single preprocessed image tensor with MiniGPT-4."""
    img_list = []
    chat_state.messages = []  # reset the shared conversation state for each image
    image_emb, _ = chat.model.encode_img(image)
    img_list.append(image_emb)
    chat_state.append_message(chat_state.roles[0], "<Img><ImageHere></Img>")
    chat_state.messages[-1][1] = ' '.join([chat_state.messages[-1][1], input_prompt])
    # Note: return_text_only may not exist in the stock Chat.answer at this commit
    # (which returns (text, output_tokens)); adjust the call or Chat.answer as needed.
    llm_message = chat.answer(
        conv=chat_state, img_list=img_list, num_beams=1, temperature=0.1,
        max_new_tokens=300, max_length=2000, return_text_only=True)
    return llm_message
def make_captions(args):
    """Caption images dataset[start] .. dataset[start + num - 1] and save the list of strings."""
    start, num, bsize = args.start, args.num, args.bsize
    end = start + num - 1
    assert 0 == num % bsize
    assert 0 == start % bsize
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)

    chat_model = prepare_model(args)
    dataset = torchvision.datasets.ImageFolder(
        args.data_path,
        transform=deepcopy(chat_model.vis_processor)
    )

    all_result = []
    save_fname = f'minigpt4_caption_imagenet_train_{start}_{end}.pth'
    time_start = time()
    print(f'##@@ START = {start} NUM = {num}')
    for index in range(start, start + num):
        img, _ = dataset[index]
        pred_string = get_single_caption(
            chat_model, img.unsqueeze(0).to(chat_model.device), input_prompt
        )
        print(pred_string)
        all_result.append(pred_string)
        if index == len(dataset) - 1:  # the requested range reaches the end of the dataset
            print('this is the last sample')
            save_fname = save_fname.replace(str(end), str(index))  # record the actual last index
            break
    print('##@@ END\n', 'total running time :', str(datetime.timedelta(seconds=int(time() - time_start))))

    save_fname = os.path.join(args.save_dir, save_fname)
    torch.save(all_result, save_fname)
    print('result saved at', save_fname)
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--data_path", type=str, default='/path/to/imagenet/train')
    parser.add_argument("--save_dir", type=str, default='/cache/output/minigpt4_captions')
    parser.add_argument("--cfg_path", type=str, default='eval_configs/minigpt4_eval.yaml')
    parser.add_argument("--gpu_id", default=0, type=int)
    parser.add_argument('--start', default=0, type=int)
    parser.add_argument('--num', default=100, type=int)
    parser.add_argument('--bsize', default=5, type=int)
    parser.add_argument(
        "--options", nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = parse_args()
    make_captions(args)
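For reference, the script saves the captions as a plain Python list of strings via torch.save, so they can be reloaded and matched back to dataset indices afterwards. Below is a minimal sketch of that step, assuming the default save_dir, start, and num values above and ImageFolder's default (sorted) ordering; the invocation in the comment is only an example.

"""load_captions.py (illustrative sketch, not part of the original script)"""
# Example invocation of the captioning script above (paths are placeholders):
#   python generate_captions.py --data_path /path/to/imagenet/train --start 0 --num 100
import os
import torch

save_dir = '/cache/output/minigpt4_captions'  # assumed: matches --save_dir above
start, num = 0, 100                           # assumed: matches --start / --num above
save_path = os.path.join(save_dir, f'minigpt4_caption_imagenet_train_{start}_{start + num - 1}.pth')

captions = torch.load(save_path)  # list of caption strings, one per image
for offset, caption in enumerate(captions):
    # Caption i corresponds to ImageFolder index start + i (default sorted order).
    print(start + offset, caption)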