feat: Add moe load balancing metrics
What does this PR do?
Adds optional tracking of MoE load-balancing (auxiliary-loss) metrics, with optional per-layer logging, to the DPO, GRPO, and SFT training paths. Tracking is controlled by two new `megatron_cfg` flags, `track_moe_metrics` and `moe_per_layer_logging`, both disabled by default.
Issues
List issues that this PR closes (syntax):
Usage
Enable the new flags under `megatron_cfg` in your training config; both default to `False`.
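A minimal sketch of the relevant config fragment, assuming the flags sit under `policy.megatron_cfg` as in the updated example configs (the exact nesting may differ between config files, and the values shown are only illustrative):

```yaml
policy:
  megatron_cfg:
    # Both flags are new in this PR and default to false.
    track_moe_metrics: true       # collect MoE load-balancing / auxiliary-loss metrics
    moe_per_layer_logging: false  # set true for per-layer MoE metric breakdowns
```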
Before your PR is "Ready for review"
Pre checks:
- [ ] Make sure you read and followed Contributor guidelines
- [ ] Did you write any new necessary tests?
- [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
- [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.
Additional Information
- ...
Summary by CodeRabbit
Release Notes
New Features
- Added MOE (Mixture of Experts) metrics tracking to training pipelines
- Per-layer metrics logging available for detailed MOE diagnostics
Configuration
- Introduced `track_moe_metrics` and `moe_per_layer_logging` flags across training configurations (disabled by default)
Tests
- Added comprehensive unit tests for MOE metrics collection and aggregation
✅ Submodule Fast-Forward Check Results
Check based on commit: b9f870f0cd25a0e82b789e1385ee5aa3aa837132 (PR #1520 from yifu/moe_metrics_main)
✅ Submodules that are properly updated:
Megatron-Bridge: ✅ PR branch is ahead of main branch (fast-forward)
All submodule changes look good! ✨
📝 Walkthrough
This PR adds Mixture of Experts (MOE) metrics tracking and per-layer logging capabilities to the NeMo RL training pipeline. Configuration flags are introduced to enable optional MOE metrics collection, a new function computes and aggregates MOE auxiliary losses, and training algorithms are updated to propagate these metrics through their results.
Changes
| Cohort / File(s) | Summary |
|---|---|
| **Configuration files**<br>`examples/configs/dpo.yaml`, `examples/configs/grpo_math_1B.yaml`, `examples/configs/sft.yaml`, `examples/configs/vlm_grpo_3B.yaml`, `examples/configs/vlm_grpo_3B_megatron.yaml`, `examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml` | Added two new boolean configuration flags under `megatron_cfg`: `track_moe_metrics` and `moe_per_layer_logging`, both defaulting to `False` |
| **Algorithm metric aggregation**<br>`nemo_rl/algorithms/dpo.py`, `nemo_rl/algorithms/grpo.py`, `nemo_rl/algorithms/sft.py` | Updated to conditionally include `moe_metrics` from `train_results` in the training metrics dictionary with a `moe/` prefix (sketched just below this table) |
| **MOE metrics computation**<br>`nemo_rl/models/megatron/common.py` | Added a new `get_moe_metrics()` function that computes, scales, and aggregates MOE auxiliary losses with optional per-layer logging |
| **Policy configuration schema**<br>`nemo_rl/models/policy/__init__.py` | Extended the `MegatronConfig` TypedDict with `track_moe_metrics` and `moe_per_layer_logging` boolean fields |
| **Policy training integration**<br>`nemo_rl/models/policy/lm_policy.py`, `nemo_rl/models/policy/megatron_policy_worker.py` | Added conditional MOE metrics propagation from worker results and integrated metrics computation in the training loop using the new `get_moe_metrics` function |
| **MOE metrics tests**<br>`tests/unit/models/megatron/test_moe_metrics.py` | Added unit tests covering `get_moe_metrics` behavior with empty trackers and aggregation with per-layer logging |
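For context, a minimal sketch of the metric propagation described in the "Algorithm metric aggregation" row above. The helper name `merge_moe_metrics` and the exact structure of `train_results` are illustrative assumptions, not the code in this PR:

```python
# Hypothetical helper illustrating the "moe/" prefixing; the real logic lives in the
# algorithm training loops (dpo.py, grpo.py, sft.py) and may be shaped differently.
def merge_moe_metrics(metrics: dict, train_results: dict) -> dict:
    """Fold optional MoE metrics into the logged training metrics under a "moe/" prefix."""
    moe_metrics = train_results.get("moe_metrics")
    if moe_metrics:  # only present when track_moe_metrics is enabled
        metrics.update({f"moe/{name}": value for name, value in moe_metrics.items()})
    return metrics
```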
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes
- Key areas requiring attention:
  - `nemo_rl/models/megatron/common.py`: verify that the loss scaling logic and the averaging across MOE layers are correct
  - `nemo_rl/models/policy/megatron_policy_worker.py`: confirm proper initialization and increment of `total_num_microbatches` and the loss scale calculation (`1 / max(1, total_num_microbatches)`; sketched below this list)
  - `tests/unit/models/megatron/test_moe_metrics.py`: ensure test coverage accurately reflects the expected aggregation behavior and per-layer logging format
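A small sketch of the microbatch averaging called out above, purely for reviewer orientation. The function and variable names other than `total_num_microbatches` are hypothetical; only the `1 / max(1, total_num_microbatches)` scale comes from the review notes:

```python
# Hypothetical illustration of the loss scaling: auxiliary losses accumulated
# per microbatch are averaged over the training step.
def average_aux_losses(per_microbatch_losses: list[dict[str, float]]) -> dict[str, float]:
    total_num_microbatches = len(per_microbatch_losses)
    loss_scale = 1.0 / max(1, total_num_microbatches)  # guard against zero microbatches
    summed: dict[str, float] = {}
    for losses in per_microbatch_losses:
        for name, value in losses.items():
            summed[name] = summed.get(name, 0.0) + value
    return {name: value * loss_scale for name, value in summed.items()}
```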
Possibly related PRs
- NVIDIA-NeMo/RL#1034: also modifies `nemo_rl/models/policy/lm_policy.py`'s `Policy.train` method; coordinate to avoid conflicts in metrics aggregation
Suggested labels
CI:L1
Suggested reviewers
- terrykong
- chtruong814
- parthchadha
- yaoyu-33
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 61.54% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
| Test Results For Major Changes | ⚠️ Warning | PR introduces major feature affecting 3 algorithms across 11+ files, but PR description lacks all testing information, convergence validation, performance analysis, and concrete details—it is essentially a template with placeholders. | Update PR description to include: unit test results, integration testing evidence across DPO/GRPO/SFT, convergence validation showing no regression, performance impact analysis, example metric outputs, and add requested docstrings for configuration flags. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'feat: Add moe load balancing metrics' accurately reflects the main changes: introducing MOE metrics tracking with configuration flags and new metric extraction logic across multiple algorithm files. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
✨ Finishing touches
- [ ] 📝 Generate docstrings
🧪 Generate unit tests (beta)
- [ ] Create PR with unit tests
- [ ] Post copyable unit tests in a comment
- [ ] Commit unit tests in branch `yifu/moe_metrics_main`