How to debug torchtune
Hi, thanks for the framework!
- I can't figure out how to debug torchtune. I want to set a breakpoint in a recipe, but how can I do that?
- Do you plan to add RLHF methods like DPO and PPO?
Hey @leo-young, thanks for the question.
The easiest way to debug that I've found is to use the Python debugger, pdb:
import pdb
# Set a breakpoint in your code
pdb.set_trace()
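For example, you could drop this into a recipe's training loop. The loop below is a generic sketch (not torchtune's actual recipe code), just to show where the call goes:

import pdb

for idx, batch in enumerate(dataloader):
    if idx == 10:        # only break at the step you care about
        pdb.set_trace()  # execution pauses here: inspect variables, step, continue
    loss = model(batch)  # ... rest of the training step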
For RLHF, we do have a recipe for DPO. See https://github.com/pytorch/torchtune/blob/main/recipes/lora_dpo_single_device.py. PPO is actively being discussed on our discord (cc @SalmanMohammadi and @kartikayk). Here is the issue discussing it: #812
Hi @leo-young. To debug, i.e. to launch the Python code directly instead of going through the CLI (which doesn't hit breakpoints for me), you can use the following code I wrote (it will need some changes for your setup):
import argparse

from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List
from torchtune._cli.run import Run
from torchtune._cli.validate import Validate

# Everything you would normally pass on the `tune` command line lives here
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
RECIPE = "full_finetune_single_device"
MODEL_CONFIG = "llama3/8B_full_single_device"
ACTIVATE_WANDB = True
MAX_STEPS_PER_EPOCH = 20
EPOCHS = 2
HF_CHECKPOINTER = True
# HF_CHECKPOINTER = False

# Build the argv that `tune run ...` would normally receive
run_command = [
    "run",
    RECIPE,
    "--config",
    MODEL_CONFIG,
]
run_command.append(f"model_name={MODEL_NAME}")
if ACTIVATE_WANDB:
    run_command.append("metric_logger._component_=torchtune.utils.metric_logging.WandBLogger")
    run_command.append("metric_logger.project=torchtune")
if MAX_STEPS_PER_EPOCH:
    run_command.append(f"max_steps_per_epoch={MAX_STEPS_PER_EPOCH}")
if EPOCHS:
    run_command.append(f"epochs={EPOCHS}")
if HF_CHECKPOINTER:
    run_command.append("checkpointer._component_=torchtune.utils.FullModelHFCheckpointer")


class TuneCLIParser:
    """Holds all information related to running the CLI"""

    def __init__(self):
        # Initialize the top-level parser
        self._parser = argparse.ArgumentParser(
            prog="tune",
            description="Welcome to the TorchTune CLI!",
            add_help=True,
        )
        # Default command is to print help
        self._parser.set_defaults(func=lambda args: self._parser.print_help())
        # Add subcommands
        subparsers = self._parser.add_subparsers(title="subcommands")
        Download.create(subparsers)
        List.create(subparsers)
        Copy.create(subparsers)
        Run.create(subparsers)
        Validate.create(subparsers)

    def parse_args(self, args=None) -> argparse.Namespace:
        """Parse CLI arguments (defaults to sys.argv when args is None)"""
        return self._parser.parse_args(args)

    def run(self, args: argparse.Namespace) -> None:
        """Execute CLI"""
        args.func(args)


def main():
    parser = TuneCLIParser()
    args = parser.parse_args(run_command)  # parse the hard-coded argv above
    parser.run(args)


if __name__ == "__main__":
    main()
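As a shorter variant of the above, you can skip rebuilding the parser and call the CLI entry point in-process, so breakpoints land in the same process. This assumes your installed version exposes torchtune._cli.tune:main as the tune console script (check your version's pyproject.toml):

import sys
from torchtune._cli.tune import main  # assumed entry point; may move between versions

# main() parses sys.argv just like the `tune` command would
sys.argv = ["tune", "run", "full_finetune_single_device",
            "--config", "llama3/8B_full_single_device"]
main()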
@optimass I'm a bit surprised to hear you don't hit breakpoints when running via the CLI. Personally I have no issues doing this via the method described by @RdoubleA. Are you sure it's not due to running a distributed recipe (there are other issues in that case that are unrelated to the CLI)?
Could be that I'm using VS Code's visual debugger!
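If VS Code launches the tune console script in a way that doesn't bind your editor breakpoints, one workaround is to attach the debugger from inside the recipe instead. A minimal sketch using debugpy (pip install debugpy; listen, wait_for_client, and breakpoint are its documented calls):

import debugpy

debugpy.listen(5678)       # start a debug server on localhost:5678
debugpy.wait_for_client()  # block until VS Code attaches (via a "Remote Attach" config)
debugpy.breakpoint()       # behaves like a regular breakpoint once attached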
Closing this since I think it's resolved. But feel free to re-open if you run into other issues here!