Error while running inference with generate_v2.py after one generation
Hey, I made a small change in generate_v2.py to run a loop over the whole test set. After the first generation I get an error, which I think is caused by the KV caching. I've pasted the code and the error message below.
Code
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
from typing import Any, Dict, List
import torch
from omegaconf import DictConfig, OmegaConf
from torchtune import config, training, utils
from torchtune.data import load_image, Message, padded_collate_tiled_images_and_mask
from torchtune.generation import sample
from torchtune.modules.transforms import Transform
import json
from evalplus.data import get_human_eval_plus
from datasets import load_dataset, concatenate_datasets
import pandas as pd
class SingleTurnYAMLToMessages(Transform):
"""
Converts a single turn conversation in YAML format to a list of messages.
Expects the YAML to look like:
system: You are a helpful AI assistant.
user: What is the capital of France?
or if it includes an image:
system: You are a helpful AI assistant.
user:
image: url or path_to_image
text: Describe the image in detail.
"""
def __call__(self, prompt: Dict[str, Any]) -> List[Message]:
messages = []
# Iterate through roles and add content
for role, content in prompt.items():
if isinstance(content, str):
new_content = [{"type": "text", "content": content}]
else:
assert (
"image" in content.keys()
), "Multiple entries per role expect an image key"
image_loc = content["image"]
image = load_image(image_loc)
new_content = [
{"type": "image", "content": image},
{"type": "text", "content": content["text"]},
]
messages.append(Message(role=role, content=new_content))
# Finally, add an empty assistant message to kick-start generation
messages.append(Message(role="assistant", content=""))
return messages
class InferenceRecipe:
"""
Recipe for generating tokens from a dense Transformer-based LLM.
This works for text-only generation and image-text generation.
This *does not* currently support the following features:
- torch.compile
- quantization through torchao
- multi-GPU generation
- batch generation
"""
def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
training.set_seed(seed=cfg.seed)
def setup(self, cfg: DictConfig) -> None:
"""Setup the model and transforms."""
# Load checkpointer and state_dict
_checkpointer = config.instantiate(cfg.checkpointer)
_ckpt_dict = _checkpointer.load_checkpoint()
# Instantiate model
with training.set_default_dtype(self._dtype), self._device:
model = config.instantiate(cfg.model)
model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
self.model = model
self._logger.info(f"Model was initialized with precision {self._dtype}.")
# Instantiate transforms
self.model_transform = config.instantiate(cfg.tokenizer)
self.to_messages = SingleTurnYAMLToMessages()
def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
"""Logs the following metrics: total time for inference, tokens/sec,
bandwidth achieved, and max memory allocated.
Feel free to modify this function to log additional metrics.
"""
model_size = sum(
[
p.numel() * p.dtype.itemsize
for p in itertools.chain(self.model.parameters(), self.model.buffers())
]
)
self._logger.info(
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
)
@torch.inference_mode()
def generate(self, cfg, prompt1, task_id):
"""The main entry point for generating tokens from a prompt."""
# 1. Convert input to messages
prom={'system': 'Please provide a self-contained Python script that solves the following problem in a markdown code block', 'user': prompt1}
messages = self.to_messages(OmegaConf.to_container(OmegaConf.create(prom)))
is_multimodal_input = any([m.contains_media for m in messages])
# 2. Apply model transform
model_inputs = self.model_transform({"messages": messages}, inference=True)
seq_len = len(model_inputs["tokens"])
total_response_length = seq_len + cfg.max_new_tokens
# 3. Setup KV cache
with self._device:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
encoder_max_seq_len=(
self.model_transform.image_seq_len if is_multimodal_input else None
),
decoder_max_seq_len=total_response_length,
)
# 4. Pre-allocate causal mask and input_pos
causal_mask = torch.tril(
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
device=self._device,
)
)
input_pos = torch.arange(total_response_length)
# 5. Collate to batch size of 1 and tensor-ify
batch = {}
if is_multimodal_input:
batch = padded_collate_tiled_images_and_mask(
[model_inputs], pad_direction="left", pad_max_images=1
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
prompt = batch.pop("tokens").to(self._device)
else:
prompt = torch.tensor(
model_inputs["tokens"], device=self._device
).unsqueeze(0)
batch["mask"] = causal_mask[None, :seq_len]
batch["input_pos"] = input_pos[None, :seq_len]
utils.batch_to_device(batch, self._device)
# 6. Prefill step
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
generated_tokens.append(token.item())
if is_multimodal_input:
# Don't need image info b/c we only support 1 image and it's been
# processed by the model now
batch.pop("encoder_input")
batch["encoder_mask"] = batch["encoder_mask"][:, -1:]
# 7. Continue generating
for i in range(cfg.max_new_tokens):
# Update position and mask for incremental decoding
batch["input_pos"] = input_pos[None, seq_len]
batch["mask"] = causal_mask[None, seq_len, None, :]
if token.item() in self.model_transform.stop_tokens:
break
logits = self.model(token, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
generated_tokens.append(token.item())
seq_len += 1
t = time.perf_counter() - t0
# 8. Translate tokens back to text
decoded = self.model_transform.decode(generated_tokens)
self._logger.info(f"\n\n{decoded}\n")
result={
"task_id": task_id,
"solution": decoded
}
append_to_json("/home/toolkit/scratch/LLMcode/Checkpoints/Fine_tuning_models-3B-PT/output.json", result)
# 9. Log metrics
tokens_per_second = len(generated_tokens) / t
self.log_metrics(total_time=t, tokens_per_second=tokens_per_second)
def append_to_json(file_path, data):
try:
# Read existing data from the file
with open(file_path, "r") as f:
file_data = json.load(f)
except FileNotFoundError:
# If the file does not exist, create an empty list
file_data = []
# Append the new data to the existing list
file_data.append(data)
# Write the updated data back to the file
with open(file_path, "w") as f:
json.dump(file_data, f, indent=4)
def prepare_code_sample(code_gen, id) -> Dict[str, Any]:
input_pr=code_gen[code_gen["task_id"]==id]["text"].values[0].split("### Answer:\nBelow is a Python script with a self-contained function that solves the problem and passes corresponding tests:")[0].split("Please provide a self-contained Python script that solves the following problem in a markdown code block:")[1].strip()
return input_pr
@config.parse
def main(cfg: DictConfig) -> None:
config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
recipe = InferenceRecipe(cfg=cfg)
recipe.setup(cfg=cfg)
ds=load_dataset("Dataset")
new = concatenate_datasets([ds['train']])
code_gen=pd.DataFrame(new)
for id in list(code_gen["task_id"].values):
prompt=prepare_code_sample(code_gen,id)
recipe.generate(cfg, prompt,id)
if __name__ == "__main__":
sys.exit(main())
Error message
INFO:torchtune.utils._logging:Time for inference: 4.34 sec total, 7.83 tokens/sec
INFO:torchtune.utils._logging:Bandwidth achieved: 51.59 GB/s
INFO:torchtune.utils._logging:Max memory allocated: 6.67 GB
hhihihihihihi
WARNING:torchtune.modules.attention:Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping.
[... the same warning is repeated 27 more times ...]
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/toolkit/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/home/toolkit/.conda/envs/torch/bin/tune", line 8, in <module>
sys.exit(main())
^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/tune.py", line 49, in main
parser.run(args)
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/tune.py", line 43, in run
args.func(args)
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/run.py", line 187, in _run_cmd
self._run_single_device(args)
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/_cli/run.py", line 96, in _run_single_device
runpy.run_path(str(args.recipe), run_name="__main__")
File "<frozen runpy>", line 291, in run_path
File "<frozen runpy>", line 98, in _run_module_code
File "<frozen runpy>", line 88, in _run_code
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/recipes/dev/generate_v2.py", line 272, in <module>
sys.exit(main())
^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/config/_parse.py", line 99, in wrapper
sys.exit(recipe_main(conf))
^^^^^^^^^^^^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/recipes/dev/generate_v2.py", line 267, in main
recipe.generate(cfg, prompt,id)
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/recipes/dev/generate_v2.py", line 176, in generate
logits = self.model(prompt, **batch)[:, -1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/transformer.py", line 599, in forward
h = layer(
^^^^^^
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/transformer.py", line 114, in forward
attn_out = self.attn(h, h, mask=mask, input_pos=input_pos)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/.conda/envs/torch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/attention.py", line 297, in forward
output = self._attention_call(
^^^^^^^^^^^^^^^^^^^^^
File "/home/toolkit/scratch/LLMcode/Train/torchtune-2/torchtune/torchtune/modules/attention_utils.py", line 236, in _attention_call
return nn.functional.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The expanded size of the tensor (285) must match the existing size (256) at non-singleton dimension 3. Target sizes: [1, 24, 56, 285]. Tensor sizes: [1, 1, 56, 256]
Ahh, looks like you are attempting to modify the generation script to run inference multiple times. I'm super happy you are hacking on our recipes; that's exactly what they're for!
We use a very basic static key-value cache in our library: it is allocated once, for a fixed size. So when you run your second inference, the model warns that the caches are already set up and skips the call. The cache then stays sized for the first prompt, while every later prompt builds its attention mask for its own length, so as soon as those sizes disagree (or a prompt simply doesn't fit in the cache), the attention call errors out. That's the 285-vs-256 mismatch in your traceback.
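To tie that to your traceback, here is roughly what happens across your two calls to generate() (a sketch in terms of your script; seq_len_1 and seq_len_2 are just names I'm using for the two prompt lengths, and the 285/256 figures are read off the error message):

# First generate() call: the static KV cache is allocated once, sized for
# this prompt only.
self.model.setup_caches(
    batch_size=1,
    dtype=self._dtype,
    decoder_max_seq_len=seq_len_1 + cfg.max_new_tokens,  # 285 in your run
)

# Second generate() call: setup_caches() is skipped (hence the "already setup"
# warnings), so the cached keys/values are still 285 positions wide. But the
# causal mask is rebuilt for this prompt's total_response_length
# (seq_len_2 + cfg.max_new_tokens = 256), and the 256-wide mask cannot be
# broadcast against the 285-wide cache inside scaled_dot_product_attention,
# which is the RuntimeError you hit.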
There are a couple of ways to fix this:
- Remove KV-caching from your script. This would slow down inference (possibly a lot), but it saves you from dealing with any caching logic at all. There's a rough sketch of what this could look like after the pseudo-code below.
- Size your KV cache for the longest sequence in your dataset. Be aware, this could increase your memory usage quite a bit. Then, after each inference, call model.reset_caches(), which will zero out the KV cache. You will have to move some of the caching logic around, because it looks like you load the dataset and iterate over samples outside of the recipe.generate function. Here's some pseudo-code of what that might look like:
code_gen = pd.DataFrame(new)
# Longest prompt in the dataset. I'm not sure which column holds the prompt
# text, and ideally you'd measure this in tokens (via the model transform)
# rather than characters; this is just to set an upper bound.
max_prompt_len = max(len(p) for p in code_gen["prompt"])
recipe.model.setup_caches(
    batch_size=1,
    dtype=recipe._dtype,
    encoder_max_seq_len=None,  # doesn't look like we have multimodal input
    decoder_max_seq_len=max_prompt_len + cfg.max_new_tokens,
)
for id in list(code_gen["task_id"].values):
    prompt = prepare_code_sample(code_gen, id)
    recipe.generate(cfg, prompt, id)
    recipe.model.reset_caches()  # zero out the cache before the next sample
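And for option 1, dropping the KV cache entirely, a rough text-only sketch of what the decode loop in generate() could look like. It assumes you skip setup_caches() altogether, so the model just runs regular causal attention over the full sequence on every step; much slower, but there is no cache state to manage between prompts:

# No KV cache: re-run the whole sequence each step and sample from the last
# position. Assumes setup_caches() is never called anywhere.
tokens = torch.tensor(model_inputs["tokens"], device=self._device).unsqueeze(0)
generated_tokens = []
for _ in range(cfg.max_new_tokens):
    logits = self.model(tokens)[:, -1]
    token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
    generated_tokens.append(token.item())
    if token.item() in self.model_transform.stop_tokens:
        break
    tokens = torch.cat([tokens, token], dim=-1)  # append and repeat
decoded = self.model_transform.decode(generated_tokens)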
Let me know how this goes.