
feat: callback support #805

Open · yuhezhang-ai opened this pull request 1 month ago · 1 comment

Implements a PyTorch Lightning-style callback system for Automodel recipes, enabling custom integrations (e.g., Customizer metrics reporting, custom monitoring).

Key Features

  • 6 lifecycle hooks: on_train_start, on_train_batch_end, on_validation_end, on_save_checkpoint, on_exception, on_train_end
  • @rank_zero_only decorator for convenience in distributed training (sketched together with the hook surface after this list)
  • Full training context passed to callbacks: recipe object, MetricsSample objects, checkpoint info, etc.
  • Programmatic API: Callbacks passed directly to recipe constructors
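
Taken together, the base class and the decorator might look like the sketch below. The method names come from the list above and the recipe-first signature from the Usage example; everything else (the `...` bodies, the rank-0 check) is an assumption about, not a copy of, the code in nemo_automodel/components/callbacks/.

# Minimal sketch of the hook surface, assuming signatures mirror the Usage
# example below (recipe first, context via **kwargs); the real base class
# lives in nemo_automodel/components/callbacks/ and may differ.
import functools

import torch.distributed as dist


class Callback:
    """Base class: subclasses override only the hooks they need."""

    def on_train_start(self, recipe, **kwargs): ...
    def on_train_batch_end(self, recipe, **kwargs): ...
    def on_validation_end(self, recipe, **kwargs): ...
    def on_save_checkpoint(self, recipe, **kwargs): ...
    def on_exception(self, recipe, **kwargs): ...
    def on_train_end(self, recipe, **kwargs): ...


def rank_zero_only(fn):
    """Assumed semantics: call the wrapped hook only on global rank 0."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Run unconditionally outside distributed jobs; otherwise rank 0 only.
        if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
            return fn(*args, **kwargs)

    return wrapper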

Changes

  • Added Callback base class and CallbackRunner in nemo_automodel/components/callbacks/ (a dispatch sketch follows this list)
  • Integrated callbacks into all recipe classes (LLM finetune, VLM finetune, sequence classification, knowledge distillation)
  • Comprehensive example: examples/llm_finetune/finetune_with_callback.py
  • 9 unit tests covering all hooks and distributed behavior
  • Documentation: docs/guides/callbacks.md
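
For readers skimming the diff, this is the kind of fan-out a runner such as CallbackRunner typically performs; it is a sketch of the idea, not the PR's actual implementation.

# Hypothetical dispatcher: invoke one named hook on every registered
# callback, in registration order.
class CallbackRunner:
    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def call(self, hook_name, recipe, **kwargs):
        for cb in self.callbacks:
            hook = getattr(cb, hook_name, None)
            if callable(hook):
                hook(recipe, **kwargs)

A recipe would then invoke, for example, runner.call('on_train_batch_end', self, train_log_data=log_data) at the end of every training step (hook and kwarg names per the Usage section below).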

Usage

from nemo_automodel.components.callbacks import Callback

class MetricsReporter(Callback):
    def on_train_batch_end(self, recipe, **kwargs):
        log_data = kwargs['train_log_data']
        # Send the step's metrics to an external system here.

recipe = TrainFinetuneRecipeForNextTokenPrediction(
    cfg,
    callbacks=[MetricsReporter()],
)
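
Combining this with the @rank_zero_only decorator from the feature list, a distributed-safe reporter might look like the following. Importing rank_zero_only alongside Callback is an assumption; the PR does not show where the decorator lives.

# Assumed import path for rank_zero_only; only Callback's path is
# confirmed by the Usage example above.
from nemo_automodel.components.callbacks import Callback, rank_zero_only

class RankZeroReporter(Callback):
    @rank_zero_only
    def on_validation_end(self, recipe, **kwargs):
        # In a multi-GPU job this body runs on rank 0 only, so each
        # validation result is reported exactly once.
        print("validation finished")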

Validation

Unit tests: All 9 tests pass ✅

python -m unittest tests.unit_tests.recipes.test_callbacks -v
Ran 9 tests in 0.006s
OK
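
As a flavor of what the suite covers, a hook-invocation test could be as simple as the sketch below; this is hypothetical, and the real nine tests live in tests/unit_tests/recipes/test_callbacks.py.

import unittest

from nemo_automodel.components.callbacks import Callback


class RecordingCallback(Callback):
    """Records which hooks fire, so the test can assert on the sequence."""

    def __init__(self):
        self.calls = []

    def on_train_start(self, recipe, **kwargs):
        self.calls.append("on_train_start")

    def on_train_end(self, recipe, **kwargs):
        self.calls.append("on_train_end")


class TestCallbackHooks(unittest.TestCase):
    def test_hooks_fire_in_order(self):
        cb = RecordingCallback()
        # Drive the hooks directly; the real tests presumably go through
        # CallbackRunner and a recipe object instead.
        cb.on_train_start(recipe=None)
        cb.on_train_end(recipe=None)
        self.assertEqual(cb.calls, ["on_train_start", "on_train_end"])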

Example run:

python examples/llm_finetune/finetune_with_callback.py

Verified that callbacks execute at the correct lifecycle points, with custom log prefixes:

2025-11-22 22:36:19 | INFO | __main__ | [SimpleLoggingCallback] 🔥 Training is starting!
2025-11-22 22:36:19 | INFO | __main__ | [SimpleLoggingCallback]    World size: 1 GPUs
2025-11-22 22:36:19 | INFO | __main__ | [SimpleLoggingCallback]    Total steps: 100
2025-11-22 22:37:33 | INFO | __main__ | [SimpleLoggingCallback] 🚀 Step 0/100: Loss = 0.9377, LR = 1.00e-05
...
2025-11-22 22:52:03 | INFO | __main__ | [SimpleLoggingCallback] ✅ Validation 'default': Loss = 0.2208
2025-11-22 22:52:03 | INFO | __main__ | [MetricsCollectorCallback] 📊 Collected 2 validation checkpoints
2025-11-22 22:52:03 | INFO | __main__ | [SimpleLoggingCallback] 💾 Checkpoint saved at step 99, epoch 0, path: checkpoints/epoch_0_step_99
2025-11-22 22:52:03 | INFO | __main__ | [MetricsCollectorCallback] 💾 Tracked checkpoint 2: step=99, train_loss=0.1678
2025-11-22 22:52:03 | INFO | __main__ | [SimpleLoggingCallback] 🎉 Training completed successfully! Final step: 100
2025-11-22 22:52:03 | INFO | __main__ | [MetricsCollectorCallback] 🎉 Training complete! Collected 100 training steps

Closes #805


Note: Happy to implement this callback feature! I designed it to be familiar (PyTorch Lightning-style) and practical for real integrations. Feel free to provide feedback on the design/API if you'd like any adjustments. (AI assisted with documentation generation, but I personally reviewed and refined everything to ensure quality and accuracy.)

— yuhezhang-ai, Nov 23 '25

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

— copy-pr-bot[bot], Nov 23 '25