
How to debug torchtune

leo-young opened this issue 9 months ago • 4 comments

Hi, thanks for the framework!

  1. I can't figure out how to debug torchtune. I want to set a breakpoint in a recipe, but how do I actually debug it?
  2. Do you plan to add RLHF methods like DPO and PPO?

leo-young · May 12 '24 13:05

Hey @leo-young, thanks for the question.

The easiest way I've found to debug is with the built-in Python debugger, pdb:

import pdb

# Set a breakpoint in your code
pdb.set_trace()
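
On Python 3.7+ the built-in breakpoint() does the same thing without the import. As a minimal sketch (the class and attribute names here are illustrative, not the exact recipe code), you drop it wherever you want execution to pause, e.g. inside a recipe's training loop:

class MyRecipe:
    def train(self) -> None:
        for idx, batch in enumerate(self._dataloader):
            breakpoint()  # pauses here and opens a pdb prompt in your terminal
            ...

Then launch with tune run <recipe> --config <config> as usual, and the prompt appears in the terminal you launched from.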

For RLHF, we do have a recipe for DPO: see https://github.com/pytorch/torchtune/blob/main/recipes/lora_dpo_single_device.py. PPO is actively being discussed on our Discord (cc @SalmanMohammadi and @kartikayk); the issue tracking it is #812.

RdoubleA · May 12 '24 17:05

Hi @leo-young. To debug, i.e. to launch the Python code directly instead of going through the CLI (which didn't hit breakpoints for me), you can use the following code I wrote (it will need some changes):

import argparse
from pathlib import Path

from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List
from torchtune._cli.run import Run
from torchtune._cli.validate import Validate


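# These constants mirror the CLI invocation:
#   tune run full_finetune_single_device --config llama3/8B_full_single_device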
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
RECIPE = "full_finetune_single_device"
MODEL_CONFIG = "llama3/8B_full_single_device"


ACTIVATE_WANDB = True
MAX_STEPS_PER_EPOCH = 20
EPOCHS = 2

HF_CHECKPOINTER = True
# HF_CHECKPOINTER = False

run_command = [
    "run",
    RECIPE,
    "--config",
    MODEL_CONFIG,
]

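# Each string appended below is a key=value config override, exactly as it
# would be passed on the command line after `tune run <recipe> --config <config>`.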
run_command.append(f"model_name={MODEL_NAME}")

if ACTIVATE_WANDB:
    run_command.append("metric_logger._component_=torchtune.utils.metric_logging.WandBLogger")
    run_command.append("metric_logger.project=torchtune")

if MAX_STEPS_PER_EPOCH:
    run_command.append(f"max_steps_per_epoch={MAX_STEPS_PER_EPOCH}")
if EPOCHS:
    run_command.append(f"epochs={EPOCHS}")

if HF_CHECKPOINTER:
    run_command.append("checkpointer._component_=torchtune.utils.FullModelHFCheckpointer")


class TuneCLIParser:
    """Holds all information related to running the CLI"""

    def __init__(self):
        # Initialize the top-level parser
        self._parser = argparse.ArgumentParser(
            prog="tune",
            description="Welcome to the TorchTune CLI!",
            add_help=True,
        )
        # Default command is to print help
        self._parser.set_defaults(func=lambda args: self._parser.print_help())

        # Add subcommands
        subparsers = self._parser.add_subparsers(title="subcommands")
        Download.create(subparsers)
        List.create(subparsers)
        Copy.create(subparsers)
        Run.create(subparsers)
        Validate.create(subparsers)

    def parse_args(self, args=None) -> argparse.Namespace:
        """Parse CLI arguments"""
        return self._parser.parse_args(args)

    def run(self, args: argparse.Namespace) -> None:
        """Execute CLI"""
        args.func(args)


def main():
    parser = TuneCLIParser()
    args = parser.parse_args(run_command)
    parser.run(args)


if __name__ == "__main__":
    main()
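
Save this as e.g. debug_tune.py (the file name is just an example) and run it like any other Python script, e.g. under pdb:

python -m pdb debug_tune.py

Because the entry point is now a plain script rather than the tune console command, IDE debuggers can launch it directly, and breakpoints set in the recipe code will be hit.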

optimass · May 15 '24 14:05

@optimass I'm a bit surprised to hear you don't hit breakpoints when running via the CLI. Personally I have no issues doing this via the method described by @RdoubleA. Are you sure it's not due to running a distributed recipe? (There are other issues in that case that are unrelated to the CLI.)
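
For reference, on a distributed recipe a bare pdb.set_trace() on every rank will fight over the same terminal. One common workaround, sketched here with standard torch.distributed calls (nothing torchtune-specific), is to break on a single rank and park the others at a barrier:

import torch.distributed as dist

# Assumes the process group is already initialized, as it is inside a distributed recipe.
if dist.get_rank() == 0:
    breakpoint()  # only rank 0 gets the pdb prompt
dist.barrier()    # other ranks wait here so the job doesn't run ahead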

ebsmothers · May 15 '24 18:05

Could be that I'm using VS Code's visual debugger!

optimass · May 15 '24 20:05

Closing this since I think it's resolved. But feel free to re-open if you run into other issues here!

ebsmothers · May 22 '24 04:05