feat: callback support #805
Implements a PyTorch Lightning-style callback system for Automodel recipes to enable custom integrations (e.g., Customizer metrics reporting, custom monitoring).
Key Features
- 6 lifecycle hooks: `on_train_start`, `on_train_batch_end`, `on_validation_end`, `on_save_checkpoint`, `on_exception`, `on_train_end`
- `@rank_zero_only` decorator for distributed-training convenience (see the sketch after this list)
- Full training context passed to callbacks: the recipe object, `MetricsSample` objects, checkpoint info, etc.
- Programmatic API: callbacks are passed directly to recipe constructors
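To make the decorator concrete, here is a minimal sketch; the callback class is hypothetical, and the import location of `rank_zero_only` is an assumption (the confirmed export from this PR is `Callback`):

```python
# `rank_zero_only` import path is assumed; see docs/guides/callbacks.md.
from nemo_automodel.components.callbacks import Callback, rank_zero_only


class RankZeroLogger(Callback):  # hypothetical example callback
    @rank_zero_only
    def on_train_start(self, recipe, **kwargs):
        # Runs only on global rank 0; the hook is a no-op on every other
        # rank, so multi-GPU runs do not emit duplicated log lines.
        print("Training is starting")
```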
Changes
- Added `Callback` base class and `CallbackRunner` in `nemo_automodel/components/callbacks/` (dispatch pattern sketched below)
- Integrated callbacks into all recipe classes (LLM finetune, VLM finetune, sequence classification, knowledge distillation)
- Comprehensive example: `examples/llm_finetune/finetune_with_callback.py`
- 9 unit tests covering all hooks and distributed behavior
- Documentation: `docs/guides/callbacks.md`
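For reviewers reading this without the diff open, here is a minimal sketch of the dispatch pattern a runner like this typically implements; it is illustrative only, not the actual `CallbackRunner` code:

```python
class CallbackRunner:
    """Illustrative sketch; the real implementation lives in
    nemo_automodel/components/callbacks/."""

    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def call(self, hook_name, recipe, **kwargs):
        # Invoke the named hook (e.g. "on_train_batch_end") on each
        # registered callback in order, forwarding the recipe and any
        # hook-specific context.
        for cb in self.callbacks:
            getattr(cb, hook_name)(recipe, **kwargs)
```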
Usage
```python
from nemo_automodel.components.callbacks import Callback


class MetricsReporter(Callback):
    def on_train_batch_end(self, recipe, **kwargs):
        log_data = kwargs['train_log_data']
        # Send metrics to an external system


# `cfg` and the recipe class come from the usual recipe setup.
recipe = TrainFinetuneRecipeForNextTokenPrediction(
    cfg,
    callbacks=[MetricsReporter()],
)
```
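Other hooks follow the same pattern. Below is a hedged sketch for the checkpoint hook; the kwarg names are guesses based on the example logs further down (step, epoch, path), not the documented payload — consult `docs/guides/callbacks.md` for the real keys:

```python
class CheckpointAuditor(Callback):  # hypothetical example
    def on_save_checkpoint(self, recipe, **kwargs):
        # Kwarg names are illustrative; the hook receives checkpoint info
        # such as the save path, step, and epoch (per the example logs).
        path = kwargs.get("checkpoint_path")
        step = kwargs.get("step")
        print(f"Checkpoint written at step {step}: {path}")
```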
Validation
Unit tests: All 9 tests pass ✅

```
python -m unittest tests.unit_tests.recipes.test_callbacks -v
Ran 9 tests in 0.006s
OK
```
Example run:

```
python examples/llm_finetune/finetune_with_callback.py
```
Verified that callbacks execute at the correct lifecycle points with custom log prefixes:
```
2025-11-22 22:36:19 | INFO | __main__ | [SimpleLoggingCallback] 🔥 Training is starting!
2025-11-22 22:36:19 | INFO | __main__ | [SimpleLoggingCallback] World size: 1 GPUs
2025-11-22 22:36:19 | INFO | __main__ | [SimpleLoggingCallback] Total steps: 100
2025-11-22 22:37:33 | INFO | __main__ | [SimpleLoggingCallback] 🚀 Step 0/100: Loss = 0.9377, LR = 1.00e-05
...
2025-11-22 22:52:03 | INFO | __main__ | [SimpleLoggingCallback] ✅ Validation 'default': Loss = 0.2208
2025-11-22 22:52:03 | INFO | __main__ | [MetricsCollectorCallback] 📊 Collected 2 validation checkpoints
2025-11-22 22:52:03 | INFO | __main__ | [SimpleLoggingCallback] 💾 Checkpoint saved at step 99, epoch 0, path: checkpoints/epoch_0_step_99
2025-11-22 22:52:03 | INFO | __main__ | [MetricsCollectorCallback] 💾 Tracked checkpoint 2: step=99, train_loss=0.1678
2025-11-22 22:52:03 | INFO | __main__ | [SimpleLoggingCallback] 🎉 Training completed successfully! Final step: 100
2025-11-22 22:52:03 | INFO | __main__ | [MetricsCollectorCallback] 🎉 Training complete! Collected 100 training steps
```
Closes #805
Note: Happy to implement this callback feature! I designed it to be familiar (PyTorch Lightning-style) and practical for real integrations. Feel free to provide feedback on the design/API if you'd like any adjustments. (AI assisted with documentation generation, but I personally reviewed and refined everything to ensure quality and accuracy.)