litgpt icon indicating copy to clipboard operation
litgpt copied to clipboard

Add chat script for adapter checkpoints

Open RDouglasSharp opened this issue 2 years ago • 3 comments

import json
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Optional

import lightning as L
import torch

from generate import generate
from lit_parrot import Tokenizer
from lit_parrot.adapter import Parrot, Config
from lit_parrot.utils import EmptyInitOnDevice, lazy_load, check_valid_checkpoint_dir
sys.path.append(os.path.join(os.path.dirname(__file__), 'scripts'))
from prepare_alpaca import generate_prompt


def _extract_response(decoded: str, marker: str = "### Response:") -> str:
    """Return the text that follows the first ``marker`` in ``decoded``.

    Equivalent to ``decoded.split(marker)[1].strip()`` when the marker is present (anything after a
    second marker is dropped), but falls back to the whole stripped text instead of raising
    ``IndexError`` when the marker is missing.
    """
    pieces = decoded.split(marker)
    if len(pieces) < 2:
        return decoded.strip()
    return pieces[1].strip()


def main(
    prompt: str = "What would be a good movie to see, and wy do you recommend it?",
    input_string: str = "",
    interactive: bool = False,
    adapter_path: Path = Path("out/adapter/alpaca/lit_model_adapter_finetuned.pth"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-tuned-alpha-3b"),
    quantize: Optional[str] = None,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    max_seq_length: int = 1250  # set this to what you used during fine tuning
) -> None:
    """Generates a response based on a given instruction and an optional input.
    This script will only work with checkpoints from the instruction-tuned Parrot-Adapter model.
    See `finetune_adapter.py`.

    Args:
        prompt: The prompt/instruction (Alpaca style). Ignored when ``interactive`` is set, because the
            instruction is then read from standard input on every turn.
        input_string: Optional input (Alpaca style).
        interactive: If ``True``, repeatedly read an instruction from standard input and answer it until an
            empty line, Ctrl-C or end-of-input is received. If ``False``, answer ``prompt`` once and return.
        adapter_path: Path to the checkpoint with trained adapter weights, which are the output of
            `finetune_adapter.py`.
        checkpoint_dir: The path to the checkpoint folder with pretrained Parrot weights.
        quantize: Whether to quantize the model and using which method:
            ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        max_seq_length: Maximum sequence length passed to ``generate`` (defaults to 1250). Set this to what you
            used during fine tuning.
    """
    check_valid_checkpoint_dir(checkpoint_dir)

    fabric = L.Fabric(devices=1)
    # Use bfloat16 only on CUDA devices that report support for it; otherwise stay in float32.
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

    with open(checkpoint_dir / "lit_config.json") as fp:
        config = Config(**json.load(fp))

    print("Loading model ...", file=sys.stderr)
    t0 = time.time()
    with EmptyInitOnDevice(device=fabric.device, dtype=dtype, quantization_mode=quantize):
        model = Parrot(config)
    with lazy_load(checkpoint_dir / "lit_model.pth") as pretrained_checkpoint, lazy_load(
        adapter_path
    ) as adapter_checkpoint:
        # strict=False because each checkpoint is loaded on its own and neither is required
        # to contain every key of the model.
        # 1. Load the pretrained weights
        model.load_state_dict(pretrained_checkpoint, strict=False)
        # 2. Load the fine-tuned adapter weights
        model.load_state_dict(adapter_checkpoint, strict=False)

    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup(model)

    tokenizer = Tokenizer(checkpoint_dir / "tokenizer.json", checkpoint_dir / "tokenizer_config.json")

    while True:
        if interactive:
            try:
                prompt = input(">> Prompt: ")
            except (KeyboardInterrupt, EOFError):
                # Ctrl-C, Ctrl-D or a closed/piped stdin ends the session without a traceback.
                break
            if not prompt:
                break
        else:
            print(f'Prompt: {prompt}')

        # Wrap the raw instruction in the Alpaca prompt template; keep it in its own name so the
        # user's instruction is not overwritten by the templated text.
        sample = {"instruction": prompt, "input": input_string}
        full_prompt = generate_prompt(sample)
        encoded = tokenizer.encode(full_prompt, device=model.device)
        prompt_length = encoded.size(0)

        t0 = time.perf_counter()
        y = generate(
            model,
            idx=encoded,
            max_new_tokens=max_new_tokens,
            max_seq_length=max_seq_length,
            temperature=temperature,
            top_k=top_k,
            eos_id=tokenizer.eos_id,
        )
        t = time.perf_counter() - t0

        # The prompt length is subtracted below, so `y` is treated as prompt + generated tokens;
        # print only what follows the response marker.
        output = _extract_response(tokenizer.decode(y))
        print(output)

        tokens_generated = y.size(0) - prompt_length
        print(f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
        if fabric.device.type == "cuda":
            print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)

        if not interactive:
            break


if __name__ == "__main__":
    from jsonargparse import CLI

    # Suppress one specific warning by its message text.
    # (Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31)
    warnings.filterwarnings(
        action="ignore",
        message="ComplexHalf support is experimental and many operators don't support it yet",
    )
    # Select PyTorch's "high" float32 matmul precision setting before any model code runs.
    torch.set_float32_matmul_precision("high")

    # Expose `main`'s parameters as command-line options and run it.
    CLI(main)

RDouglasSharp avatar May 20 '23 15:05 RDouglasSharp

Instead of an --interactive flag, it would be better to add a chat_adapter.py script that supports interactive use and streams the output.

Since this adds quite a bit of logic, it's better to keep the scripts separate.

Closing https://github.com/Lightning-AI/lit-parrot/issues/78 in favor of this

carmocca avatar May 21 '23 19:05 carmocca

I came here to see if anyone else was making this feature.

@agmo1993 does your adapter.py also work with v2 adapters?

iskandr avatar Jun 07 '23 16:06 iskandr

Not yet, but I can try to make it work within the same script.

agmo1993 avatar Jun 08 '23 07:06 agmo1993