[Help for beginners] Code for running interleave
Great project, I appreciate it highly :) To give something back (not much, but it may help some beginners get started), here is my code for using interleave without Gradio, implemented as a class and simplified. It does not support videos; if you would like that feature, just extend the code accordingly.

I mainly replaced the nested 'history' list with a stack (implemented via Python's deque) for a typed, more efficient (okay, that only matters for really long LLM conversations), and more beginner-friendly implementation that makes it easier to follow what is going on.
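If `deque` is new to you: used this way it is simply a stack in which every element is the list of images one user turn contributed, so a whole turn can be pushed or undone in one step. A tiny standalone illustration of the idea (the placeholder strings just stand in for PIL images):

```python
from collections import deque

turns = deque()                   # one entry per user turn
turns.append(["img_a", "img_b"])  # turn 1 contributed two images
turns.append(["img_c"])           # turn 2 contributed one image

all_images = [img for turn in turns for img in turn]  # flat view, in order
last_turn = turns.pop()           # undo the most recent turn as a unit
```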
```python
import os
import requests
import torch
from collections import deque

# import cv2
from PIL import Image
from io import BytesIO

from args import LLavaNextArgs
from utils import (
    determine_llava_next_conv_mode,
)

from llava.utils import disable_torch_init
from llava.model.builder import load_pretrained_model
from llava.mm_utils import (
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
    KeywordsStoppingCriteria,
)
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
)
from transformers import TextStreamer

# from llava import conversation as conversation_lib
from _conversation import conv_templates, SeparatorStyle, Conversation
# from llava.conversation import conv_qwen

def is_valid_video_filename(name):
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    if ext in video_extensions:
        return True
    else:
        return False

# def sample_frames(video_file, num_frames):
#     video = cv2.VideoCapture(video_file)
#     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
#     interval = total_frames // num_frames
#     frames = []
#     for i in range(total_frames):
#         ret, frame = video.read()
#         if not ret:
#             continue
#         pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
#         if i % interval == 0:
#             frames.append(pil_img)
#     video.release()
#     return frames

def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            raise ValueError(f"Failed to load the image from url: {image_file}")
    else:
        print("Load image from local file")
        print(image_file)
        image = Image.open(image_file).convert("RGB")
    return image

class ImageHistory(deque[list[Image.Image]]):
    """
    Transactional-style bookkeeping of images: each element on the stack is the
    list of images added in one conversation turn, while `image_list` keeps a
    flat, linear view of all images in order.
    """

    def __init__(self) -> None:
        super().__init__()
        self.image_list: list[Image.Image] = []

    def get_image_count(self) -> int:
        return len(self.image_list)

    def get_all_images(self) -> list[Image.Image]:
        return self.image_list

    def append(self, __x: list[Image.Image]) -> None:
        self.image_list.extend(__x)
        return super().append(__x)

    def appendleft(self, __x: list[Image.Image]) -> None:
        raise NotImplementedError("Not supported")

    def extend(self, __iterable) -> None:
        raise NotImplementedError("Not supported")

    def extendleft(self, __iterable) -> None:
        raise NotImplementedError("Not supported")

    def insert(self, __i: int, __x: list[Image.Image]) -> None:
        raise NotImplementedError("Not supported")

    def pop(self) -> list[Image.Image]:
        pop_list = super().pop()
        for x in pop_list:
            self.image_list.remove(x)
        return pop_list

    def popleft(self) -> list[Image.Image]:
        raise NotImplementedError("Not supported")

    def clear(self) -> None:
        self.image_list.clear()
        return super().clear()
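
# Example (illustrative only, not executed by this script):
#   hist = ImageHistory()
#   hist.append([img_a, img_b])  # turn 1 contributed two images
#   hist.append([img_c])         # turn 2 contributed one image
#   hist.get_all_images()        # -> [img_a, img_b, img_c], flat and in order
#   hist.pop()                   # undo turn 2; img_c also leaves the flat list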

class LLaVAInterleavePredictor(object):
    def __init__(self, args: LLavaNextArgs, device="cuda") -> None:
        disable_torch_init()
        self.device = device
        self.args = args
        self.model_name = get_model_name_from_path(args.model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            load_pretrained_model(
                args.model_path,
                args.model_base,
                self.model_name,
                args.load_8bit,
                args.load_4bit,
            )
        )
        # self.model.eval()
        # self.model.tie_weights()

        conv_mode = determine_llava_next_conv_mode(self.model_name)
        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode
        self.conv_mode = args.conv_mode

        self.num_frames = args.num_frames
        self.image_history = ImageHistory()
        self.conv: Conversation = None
        self.reset_conversation()

    def reset_conversation(self) -> None:
        self.image_history.clear()
        self.conv = conv_templates[self.args.conv_mode].copy()

    def eval(self, message: str, new_images: list[Image.Image]) -> str:
        self.image_history.append(new_images)
        images = self.image_history.get_all_images()

        # image_tensor = [
        #     self.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0]
        #     .half()
        #     .to(self.model.device)
        #     for img in images
        # ]
        image_tensor = process_images(images, self.image_processor, self.model.config)
        image_tensor = [
            _image.to(dtype=torch.float16, device=self.device) for _image in image_tensor
        ]
        image_sizes = [x.size for x in images]
        image_tensor = torch.stack(image_tensor)

        # Unary token identifier for an image
        num_new_images = len(new_images)
        image_token = DEFAULT_IMAGE_TOKEN * num_new_images
        inp = image_token + "\n" + message

        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()
        # For debugging purposes (:
        print("Prompt: " + prompt)

        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(self.device)
        )
        stop_str = (
            self.conv.sep
            if self.conv.sep_style != SeparatorStyle.TWO
            else self.conv.sep2
        )
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(
            keywords, self.tokenizer, input_ids
        )
        streamer = TextStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                streamer=streamer,
                use_cache=False,
                # pad_token_id=self.tokenizer.pad_token_id,
                stopping_criteria=[stopping_criteria],
            )

        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
            0
        ].strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        # Append the answer to the conv object so the next turn sees it
        self.conv.messages[-1][-1] = outputs
        print("\nOutputs: " + outputs)
        return outputs

if __name__ == "__main__":
    # For test purposes :)
    model_path = "lmms-lab/llava-next-interleave-qwen-7b"
    # model_path = "lmms-lab/llava-next-interleave-qwen-7b-dpo"
    # model_path = "lmms-lab/llava-next-interleave-7b"
    # model_path = "lmms-lab/llama3-llava-next-8b"
    args = LLavaNextArgs()
    args.model_path = model_path
    predictor = LLaVAInterleavePredictor(args)

    img_1_path = os.path.join(os.getcwd(), "../../test_img.png")
    img_2_path = os.path.join(os.getcwd(), "../../pikachu.png")
    img_1 = Image.open(img_1_path).convert("RGB")
    img_2 = Image.open(img_2_path).convert("RGB")

    output = predictor.eval(
        "Please describe what you see in the provided image",
        [img_1],
    )
    output = predictor.eval(
        "Please describe what you see in the second provided image",
        [img_2],
    )
```
And here is the utility dataclass `LLavaNextArgs` (I like everything being typed):
```python
from dataclasses import dataclass


@dataclass
class LLavaNextArgs:
    address: str = "0.0.0.0"
    port: str = "9999"
    model_path: str = "lmms-lab/llava-next-interleave-qwen-7b-dpo"
    num_gpus: int = 1
    conv_mode: str | None = None
    temperature: float = 0.2
    max_new_tokens: int = 512
    num_frames: int = 16
    load_8bit: bool | None = None
    load_4bit: bool | None = None
    model_base: str | None = None
    debug: bool = False
```
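One more note for beginners: `utils` and `_conversation` are small local helper modules (`_conversation` is essentially a local copy of `llava/conversation.py`, which provides `conv_templates`, `SeparatorStyle`, and `Conversation`). Since `determine_llava_next_conv_mode` is not shown above, here is a minimal sketch of what it could look like; the mapping from model names to conversation template names ("qwen_1_5", "llava_llama_3", ...) is an assumption based on the model families listed under `__main__`, so check it against the `conv_templates` in your checkout:

```python
# Hypothetical sketch of utils.determine_llava_next_conv_mode -- adjust the
# template names to whatever your local conv_templates actually contains.
def determine_llava_next_conv_mode(model_name: str) -> str:
    name = model_name.lower()
    if "qwen" in name:
        return "qwen_1_5"       # assumed template for the interleave-qwen checkpoints
    if "llama3" in name or "llama-3" in name:
        return "llava_llama_3"  # assumed template for llama3-llava-next-8b
    if "v1" in name:
        return "llava_v1"
    return "llava_v0"           # conservative fallback
```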