LLaVA-NeXT

[Help for beginners] Code for running interleave

uahic opened this issue 6 months ago

Great project, I appreciate it highly :) To give something back (not much, but it may help some beginners get started), here is my code for using interleave without Gradio, implemented as a class and simplified. It does not support videos; if you would like that feature, just extend the code accordingly.

I mainly replaced the nested 'history' list with a stack (implemented via Python's collections.deque) for a typed, more efficient (okay, that only really matters for very long LLM conversations), and more beginner-friendly implementation that makes it easier to see what is going on.
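To illustrate the idea, here is a tiny usage sketch (not part of the original script) for the ImageHistory class defined below: each append records one turn's worth of images as a single stack entry, pop rolls that turn back as a unit, and the flat image_list stays in sync for feeding the model.

from PIL import Image

history = ImageHistory()
history.append([Image.new("RGB", (64, 64))])          # turn 1: one image
history.append([Image.new("RGB", (64, 64)),
                Image.new("RGB", (64, 64))])          # turn 2: two images
print(history.get_image_count())  # 3 images, in insertion order
history.pop()                     # undo turn 2 as one unit
print(history.get_image_count())  # back to 1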

import os
import requests
import torch

from collections import deque

# import cv2
from PIL import Image
from io import BytesIO

from args import LLavaNextArgs
from utils import (
    determine_llava_next_conv_mode,
)
from llava.utils import disable_torch_init
from llava.model.builder import load_pretrained_model


from llava.mm_utils import (
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from transformers import TextStreamer
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
)


# from llava import conversation as conversation_lib

from _conversation import conv_templates, SeparatorStyle, Conversation

# from llava.conversation import conv_qwen
from llava.mm_utils import KeywordsStoppingCriteria


def is_valid_video_filename(name):
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


# def sample_frames(video_file, num_frames):
#     video = cv2.VideoCapture(video_file)
#     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
#     interval = total_frames // num_frames
#     frames = []
#     for i in range(total_frames):
#         ret, frame = video.read()
#         pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
#         if not ret:
#             continue
#         if i % interval == 0:
#             frames.append(pil_img)
#     video.release()
#     return frames


def load_image(image_file):
    if image_file.startswith("http://") or image_file.startswith("https://"):
        response = requests.get(image_file)
        # Fail loudly instead of falling through with an unbound variable
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        print("Loading image from local file:", image_file)
        image = Image.open(image_file).convert("RGB")

    return image


class ImageHistory(deque[list[Image.Image]]):
    """
    This class implements a 'transactional'-style of bookkeeping images
    while also maintaining a linear list of all images
    """
    def __init__(self) -> None:
        super().__init__()
        self.image_list = []

    def get_image_count(self) -> int:
        return len(self.image_list)

    def get_all_images(self) -> list[Image.Image]:
        return self.image_list

    def append(self, __x: list[Image.Image]) -> None:
        self.image_list.extend(__x)
        return super().append(__x)

    def appendleft(self, __x: list[Image.Image]) -> None:
        raise NotImplementedError("Not supported")

    def extend(self, __iterable) -> None:
        raise NotImplementedError("Not supported")

    def extendleft(self, __iterable) -> None:
        raise NotImplementedError("Not supported")

    def insert(self, __i: int, __x: list[Image.Image]) -> None:
        raise NotImplementedError("Not supported")
    
    def pop(self) -> list[Image.Image]:
        pop_list = super().pop()
        for x in pop_list:
            self.image_list.remove(x)
        return pop_list

    def popleft(self) -> list[Image.Image]:
        raise NotImplementedError("Not supported")

    def clear(self) -> None:
        self.image_list.clear()
        return super().clear()
    
    

class LLaVAInterleavePredictor(object):
    def __init__(self, args: LLavaNextArgs, device="cuda") -> None:
        disable_torch_init()
        self.device = device
        self.args = args
        self.model_name = get_model_name_from_path(args.model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            load_pretrained_model(
                args.model_path,
                args.model_base,
                self.model_name,
                args.load_8bit,
                args.load_4bit,
            )
        )

        # self.model.eval()
        # self.model.tie_weights()

        conv_mode = determine_llava_next_conv_mode(self.model_name)
        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode

        # Keep self.conv_mode consistent with the mode actually used for the templates
        self.conv_mode = args.conv_mode
        self.conv: Conversation = None
        self.reset_conversation()
        self.num_frames = args.num_frames

        self.image_history = ImageHistory()

    def reset_conversation(self) -> None:
        self.image_history.clear()
        self.conv = conv_templates[self.args.conv_mode].copy()


    def eval(self, message: str, new_images: list[Image.Image]) -> str:

        self.image_history.append(new_images)
        images = self.image_history.get_all_images()
  
        # image_tensor = [
        #     self.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0]
        #     .half()
        #     .to(self.model.device)
        #     for img in images
        # ]
        image_tensor = process_images(images, self.image_processor, self.model.config)
        image_tensor = [
            _image.to(dtype=torch.float16, device="cuda") for _image in image_tensor
        ]

        image_sizes = [x.size for x in images]

        image_tensor = torch.stack(image_tensor)

        # One image placeholder token per newly added image
        num_new_images = len(new_images)
        image_token = DEFAULT_IMAGE_TOKEN * num_new_images

        inp = image_token + "\n" + message

        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()

        # For debugging purposes (:
        print("Prompt: " + prompt)

        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to("cuda")
        )

        stop_str = (
            self.conv.sep
            if self.conv.sep_style != SeparatorStyle.TWO
            else self.conv.sep2
        )
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(
            keywords, self.tokenizer, input_ids
        )
        streamer = TextStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                image_sizes=image_sizes,
                do_sample=True,
                temperature=self.args.temperature,
                max_new_tokens=self.args.max_new_tokens,
                streamer=streamer,
                use_cache=False,
                # pad_token_id=self.tokenizer.pad_token_id,
                stopping_criteria=[stopping_criteria],
            )
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
            0
        ].strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]

        # Append answer to conv object
        self.conv.messages[-1][-1] = outputs
        print("\nOutputs: " + outputs)
        return outputs


if __name__ == "__main__":
    # For test-purposes :)

    model_path = "lmms-lab/llava-next-interleave-qwen-7b"
    # model_path = "lmms-lab/llava-next-interleave-qwen-7b-dpo"
    # model_path = "lmms-lab/llava-next-interleave-7b"
    # model_path = "lmms-lab/llama3-llava-next-8b"
    args = LLavaNextArgs()
    args.model_path = model_path
    predictor = LLaVAInterleavePredictor(args)

    img_1_path = os.path.join(os.getcwd(), "../../test_img.png")
    img_2_path = os.path.join(os.getcwd(), "../../pikachu.png")
    img_1 = Image.open(img_1_path).convert("RGB")
    img_2 = Image.open(img_2_path).convert("RGB")

    output = predictor.eval(
        "Please describe what you see the provided image",
        [img_1],
    )

    output = predictor.eval(
        "Please describe what you see in the second provided image",
        [img_2],
    )
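
Since eval takes a list of images per turn, you can also pass several new images in a single turn, for example (not in the original script):

output = predictor.eval(
    "Compare the two provided images and describe their differences.",
    [img_1, img_2],
)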

And here is the utility dataclass LLavaNextArgs (I like everything to be typed):

from dataclasses import dataclass


@dataclass
class LLavaNextArgs:
    address: str = "0.0.0.0"
    port: str = "9999"
    model_path: str = "llms-lab/llava-next-interleave-qwen-7b-dpo"
    num_gpus: int = 1
    conv_mode: str | None = None
    temperature: float = 0.2
    max_new_tokens: int = 512
    num_frames: int = 16
    load_8bit: bool | None = None
    load_4bit: bool | None = None
    model_base: str | None = None
    debug: bool = False
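
You can then override any field via keyword arguments, e.g. args = LLavaNextArgs(model_path="lmms-lab/llava-next-interleave-qwen-7b", load_4bit=True).

The script also imports determine_llava_next_conv_mode from a local utils module that I did not include. Below is a minimal sketch of what it could look like; the template names ("qwen_1_5", "llava_llama_3", "llava_v1") are assumptions based on the templates shipped with LLaVA-NeXT, so check them against the keys of your conv_templates before relying on it.

def determine_llava_next_conv_mode(model_name: str) -> str:
    # Heuristic mapping from model name to conversation template key.
    # Assumed template names -- adjust to whatever your conv_templates contains.
    name = model_name.lower()
    if "qwen" in name:
        return "qwen_1_5"
    if "llama3" in name or "llama-3" in name:
        return "llava_llama_3"
    return "llava_v1"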

uahic, Aug 07 '24 07:08