LLaVA
[Usage] Must I reload the model when I want to run inference on a new image?
Describe the issue
Loading the model takes a long time, so I try to reuse the loaded model when running inference on a new image. But I encounter the issue below. Is it possible to do this, and how should I write the code? The code below is modified from llava/serve/cli.py:
def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    ...

    while True:
        image_file = input("Please input image path:")
        image = load_image(image_file)
        image_size = image.size
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        if type(image_tensor) is list:
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        while True:
            try:
                inp = input(f"{roles[0]}: ")
            except EOFError:
                inp = ""

            if not inp:
                print("exit...")
                break

            print(f"{roles[1]}: ", end="")

            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                image = None

            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    image_sizes=[image_size],
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True)

            outputs = tokenizer.decode(output_ids[0]).strip()
            conv.messages[-1][-1] = outputs

            if args.debug:
                print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
The code works well for the first image, but fails on the second image input.
Output:
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:29<00:00, 1.94s/it]
Please input image path:/home/aistar/llava/data/view.jpg
<|im_start|>user
: hello
<|im_start|>assistant
: Hello! This is a beautiful image of a wooden dock extending into a serene lake. The calm water reflects the surrounding landscape, which includes a forest and mountains in the distance. The sky is partly cloudy, suggesting a pleasant day. The dock appears to be a quiet spot for relaxation or perhaps a starting point for boating or fishing. It's a peaceful scene that evokes a sense of tranquility and connection with nature.
<|im_start|>user
:
exit...
Please input image path:/home/aistar/llava/data/view.jpg
<|im_start|>user
: hello
<|im_start|>assistant
: Traceback (most recent call last):
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/aistar/llava/LLaVA/llava/serve/cli_multi_turn.py", line 137, in <module>
    main(args)
  File "/home/aistar/llava/LLaVA/llava/serve/cli_multi_turn.py", line 107, in main
    output_ids = model.generate(
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/aistar/llava/LLaVA/llava/model/language_model/llava_llama.py", line 125, in generate
    ) = self.prepare_inputs_labels_for_multimodal(
  File "/home/aistar/llava/LLaVA/llava/model/llava_arch.py", line 260, in prepare_inputs_labels_for_multimodal
    cur_image_features = image_features[cur_image_idx]
IndexError: list index out of range
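For reference, the traceback appears to come from prepare_inputs_labels_for_multimodal finding more <image> tokens in the prompt than image tensors passed to generate(): the conv history built for the first image is never reset, so the second image appends another DEFAULT_IMAGE_TOKEN to a prompt that still contains the first one, while only a single image_tensor is supplied. A minimal sketch of one possible workaround (not a confirmed fix), assuming the same imports and setup as llava/serve/cli.py (conv_templates, load_image, process_images, and the constants used above), with the inner chat loop left unchanged:

# Sketch only: move the conversation reset inside the outer image loop so each new
# image starts a fresh conv containing exactly one <image> token.
from llava.conversation import conv_templates  # same import cli.py already uses

while True:
    image_file = input("Please input image path:")
    image = load_image(image_file)
    image_size = image.size
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    # Re-create the conversation state here, not once before the outer loop, so the
    # history (and the <image> placeholder) from the previous image is discarded.
    conv = conv_templates[conv_mode].copy()
    roles = conv.roles

    while True:
        # Inner chat loop exactly as above: it inserts DEFAULT_IMAGE_TOKEN only on
        # the first message of this fresh conversation, keeping image tokens and
        # image tensors one-to-one for every generate() call.
        ...

With a change along these lines the model is loaded once, and only the conversation object is rebuilt per image, which is cheap compared to load_pretrained_model.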