
Error while running inference with generate_v2.py after one generation

Open · Vattikondadheeraj opened this issue 1 year ago · 1 comment

Hey, I made a small change in generate_v2.py to loop over the whole test set, and I'm hitting an error that I think is caused by caching. I've pasted the code and the error message I get after the first generation below.

Code

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from typing import Any, Dict, List

import torch
from omegaconf import DictConfig, OmegaConf

from torchtune import config, training, utils
from torchtune.data import load_image, Message, padded_collate_tiled_images_and_mask

from torchtune.generation import sample

from torchtune.modules.transforms import Transform
import json
from evalplus.data import get_human_eval_plus
from datasets import load_dataset, concatenate_datasets
import pandas as pd




class SingleTurnYAMLToMessages(Transform):
    """
    Converts a single turn conversation in YAML format to a list of messages.

    Expects the YAML to look like:
        system: You are a helpful AI assistant.
        user: What is the capital of France?

    or if it includes an image:
        system: You are a helpful AI assistant.
        user:
            image: url or path_to_image
            text: Describe the image in detail.
    """

    def __call__(self, prompt: Dict[str, Any]) -> List[Message]:
        messages = []

        # Iterate through roles and add content
        for role, content in prompt.items():
            if isinstance(content, str):
                new_content = [{"type": "text", "content": content}]
            else:
                assert (
                    "image" in content.keys()
                ), "Multiple entries per role expect an image key"
                image_loc = content["image"]
                image = load_image(image_loc)
                new_content = [
                    {"type": "image", "content": image},
                    {"type": "text", "content": content["text"]},
                ]
            messages.append(Message(role=role, content=new_content))

        # Finally, add an empty assistant message to kick-start generation
        messages.append(Message(role="assistant", content=""))
        return messages


class InferenceRecipe:
    """
    Recipe for generating tokens from a dense Transformer-based LLM.
    This works for text-only generation and image-text generation.

    This *does not* currently support the following features:
        - torch.compile
        - quantization through torchao
        - multi-GPU generation
        - batch generation
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
        self._logger = utils.get_logger(cfg.log_level)
        training.set_seed(seed=cfg.seed)

    def setup(self, cfg: DictConfig) -> None:
        """Setup the model and transforms."""
        # Load checkpointer and state_dict
        _checkpointer = config.instantiate(cfg.checkpointer)
        _ckpt_dict = _checkpointer.load_checkpoint()

        # Instantiate model
        with training.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(cfg.model)
        model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
        self.model = model
        self._logger.info(f"Model was initialized with precision {self._dtype}.")

        # Instantiate transforms
        self.model_transform = config.instantiate(cfg.tokenizer)
        self.to_messages = SingleTurnYAMLToMessages()

    def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
        """Logs the following metrics: total time for inference, tokens/sec,
        bandwidth achieved, and max memory allocated.

        Feel free to modify this function to log additional metrics.
        """
        model_size = sum(
            [
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(self.model.parameters(), self.model.buffers())
            ]
        )
        self._logger.info(
            f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
        )
        self._logger.info(
            f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
        )
        self._logger.info(
            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
        )

    @torch.inference_mode()
    def generate(self, cfg, prompt1, task_id):
        """The main entry point for generating tokens from a prompt."""
        # 1. Convert input to messages
        prom = {
            "system": "Please provide a self-contained Python script that solves the following problem in a markdown code block",
            "user": prompt1,
        }
        messages = self.to_messages(OmegaConf.to_container(OmegaConf.create(prom)))
        is_multimodal_input = any([m.contains_media for m in messages])

        # 2. Apply model transform
        model_inputs = self.model_transform({"messages": messages}, inference=True)
        seq_len = len(model_inputs["tokens"])
        total_response_length = seq_len + cfg.max_new_tokens

        # 3. Setup KV cache
        with self._device:
            self.model.setup_caches(
                batch_size=1,
                dtype=self._dtype,
                encoder_max_seq_len=(
                    self.model_transform.image_seq_len if is_multimodal_input else None
                ),
                decoder_max_seq_len=total_response_length,
            )

        # 4. Pre-allocate causal mask and input_pos
        causal_mask = torch.tril(
            torch.ones(
                size=(total_response_length, total_response_length),
                dtype=torch.bool,
                device=self._device,
            )
        )
        input_pos = torch.arange(total_response_length)

        # 5. Collate to batch size of 1 and tensor-ify
        batch = {}
        if is_multimodal_input:
            batch = padded_collate_tiled_images_and_mask(
                [model_inputs], pad_direction="left", pad_max_images=1
            )
            batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
            prompt = batch.pop("tokens").to(self._device)
        else:
            prompt = torch.tensor(
                model_inputs["tokens"], device=self._device
            ).unsqueeze(0)
        batch["mask"] = causal_mask[None, :seq_len]
        batch["input_pos"] = input_pos[None, :seq_len]
        utils.batch_to_device(batch, self._device)

        # 6. Prefill step
        generated_tokens = []
        t0 = time.perf_counter()
        logits = self.model(prompt, **batch)[:, -1]
        token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
        generated_tokens.append(token.item())

        if is_multimodal_input:
            # Don't need image info b/c we only support 1 image and it's been
            # processed by the model now
            batch.pop("encoder_input")
            batch["encoder_mask"] = batch["encoder_mask"][:, -1:]

        # 7. Continue generating
        for i in range(cfg.max_new_tokens):

            # Update position and mask for incremental decoding
            batch["input_pos"] = input_pos[None, seq_len]
            batch["mask"] = causal_mask[None, seq_len, None, :]

            if token.item() in self.model_transform.stop_tokens:
                break

            logits = self.model(token, **batch)[:, -1]
            token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
            generated_tokens.append(token.item())
            seq_len += 1

        t = time.perf_counter() - t0

        # 8. Translate tokens back to text
        decoded = self.model_transform.decode(generated_tokens)
        self._logger.info(f"\n\n{decoded}\n")

        result = {
            "task_id": task_id,
            "solution": decoded,
        }
        append_to_json("/home/toolkit/scratch/LLMcode/Checkpoints/Fine_tuning_models-3B-PT/output.json", result)

        # 9. Log metrics
        tokens_per_second = len(generated_tokens) / t
        self.log_metrics(total_time=t, tokens_per_second=tokens_per_second)


def append_to_json(file_path, data):
    try:
        # Read existing data from the file
        with open(file_path, "r") as f:
            file_data = json.load(f)
    except FileNotFoundError:
        # If the file does not exist, start with an empty list
        file_data = []

    # Append the new data to the existing list
    file_data.append(data)

    # Write the updated data back to the file
    with open(file_path, "w") as f:
        json.dump(file_data, f, indent=4)


def prepare_code_sample(code_gen, id) -> str:
    # Extract the problem statement between the instruction and answer markers
    input_pr = (
        code_gen[code_gen["task_id"] == id]["text"]
        .values[0]
        .split("### Answer:\nBelow is a Python script with a self-contained function that solves the problem and passes corresponding tests:")[0]
        .split("Please provide a self-contained Python script that solves the following problem in a markdown code block:")[1]
        .strip()
    )
    return input_pr


@config.parse
def main(cfg: DictConfig) -> None:
    config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
    recipe = InferenceRecipe(cfg=cfg)
    recipe.setup(cfg=cfg)

    ds = load_dataset("Dataset")
    new = concatenate_datasets([ds["train"]])
    code_gen = pd.DataFrame(new)

    for id in list(code_gen["task_id"].values):
        prompt = prepare_code_sample(code_gen, id)
        recipe.generate(cfg, prompt, id)



if __name__ == "__main__":
    sys.exit(main())


Error message

INFO:torchtune.utils._logging:Time for inference: 4.34 sec total, 7.83 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 51.59 GB/s
INFO:torchtune.utils._logging:Max memory allocated: 6.67 GB
hhihihihihihi
WARNING:torchtune.modules.attention:Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping.
[... same warning repeated for each attention layer ...]

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/toolkit/.conda/envs/torch/bin/tune", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/run.py", line 187, in _run_cmd
    self._run_single_device(args)
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/run.py", line 96, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/recipes/dev/generate_v2.py", line 272, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/config/_parse.py", line 99, in wrapper
    sys.exit(recipe_main(conf))
             ^^^^^^^^^^^^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/recipes/dev/generate_v2.py", line 267, in main
    recipe.generate(cfg, prompt,id)
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/recipes/dev/generate_v2.py", line 176, in generate
    logits = self.model(prompt, **batch)[:, -1]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/transformer.py", line 599, in forward
    h = layer(
        ^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/transformer.py", line 114, in forward
    attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/attention.py", line 297, in forward
    output = self._attention_call(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/attention_utils.py", line 236, in _attention_call
    return nn.functional.scaled_dot_product_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (285) must match the existing size (256) at non-singleton dimension 3.  Target sizes: [1, 24, 56, 285].  Tensor sizes: [1, 1, 56, 256]

Vattikondadheeraj · Oct 05 '24

Ahh looks like you are attempting to modify the generation script to run inference multiple times - I'm super happy you are hacking on our recipes, that's exactly what they're for!

We use a very basic static key-value cache in our library: it's allocated once, at a fixed size. So when you try to run your second inference, the model complains that you've already set up caches and skips the call. And once you hit a prompt that needs more room than that original allocation, the cache isn't big enough and you get the shape error in your traceback.
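To make that concrete, here's a minimal sketch of the behavior, assuming `model` is the `TransformerDecoder` you load in `setup` (the dtype and sizes below are illustrative, taken from your traceback):

# Hypothetical illustration only -- the cache is allocated once at a fixed size.
model.setup_caches(batch_size=1, dtype=torch.bfloat16, decoder_max_seq_len=256)

# A second call is skipped: you get the
# "Key value caches are already setup ... Skipping." warnings and keep the 256-slot cache.
model.setup_caches(batch_size=1, dtype=torch.bfloat16, decoder_max_seq_len=512)

# Any later sample where prompt_len + max_new_tokens > 256 then overflows the cache,
# which is the "expanded size (285) must match the existing size (256)" error you hit.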

There are a couple of ways to fix this:

  1. Remove KV-caching from your script. This will slow down inference (possibly a lot), but it saves you from dealing with any caching logic; there's a rough sketch of this after the pseudo-code below.
  2. Size your KV-cache for the longest prompt in your dataset. Be aware that this could increase your memory usage quite a bit. Then, after each inference, call model.reset_caches(), which zeroes out the KV-cache. You will have to move some of the caching logic around, because you load the dataset and iterate over samples outside of recipe.generate. Here's some pseudo-code of what that might look like:

code_gen = pd.DataFrame(new)
# Length of the longest prompt -- idk what column it would be in your dataset,
# and you'll want to measure it in tokens (with your tokenizer), not characters.
max_prompt_len = max(len(p) for p in code_gen["prompt"])

with recipe._device:
    recipe.model.setup_caches(
        batch_size=1,
        dtype=recipe._dtype,
        encoder_max_seq_len=None,  # doesn't look like we have multimodal input
        decoder_max_seq_len=max_prompt_len + cfg.max_new_tokens,
    )

for id in list(code_gen["task_id"].values):
    prompt = prepare_code_sample(code_gen, id)
    recipe.generate(cfg, prompt, id)
    recipe.model.reset_caches()  # zero out the cache before the next sample
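And if you'd rather go with option 1 and drop KV-caching entirely, the decode loop in generate could look roughly like this (a minimal, untested sketch, not the recipe's official path: skip step 3, drop the mask/input_pos bookkeeping, and re-feed the whole sequence every step, trading speed for simplicity):

# No setup_caches, no causal_mask / input_pos batch entries -- causal masking is the
# default when no cache is set up. Every step re-runs the full prompt + generated tokens.
prompt = torch.tensor(model_inputs["tokens"], device=self._device).unsqueeze(0)
generated_tokens = []

t0 = time.perf_counter()
for _ in range(cfg.max_new_tokens):
    logits = self.model(prompt)[:, -1]
    token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
    generated_tokens.append(token.item())
    if token.item() in self.model_transform.stop_tokens:
        break
    prompt = torch.cat([prompt, token], dim=-1)  # append the new token and re-feed everything
t = time.perf_counter() - t0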

Let me know how this goes.

joecummings · Oct 07 '24