fix: Use Float16Module even when defer_fp32_logits=True
What does this PR do?
Previously, we were skipping the Float16Module when defer_fp32_logits=True. This PR changes the logic so that we still use the Float16Module when defer_fp32_logits=True, but only skip the final cast to fp32.
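The change can be sketched as follows. This is an illustrative reduction, not the actual nemo_rl code: `Float16Module` here is a minimal stand-in for Megatron's mixed-precision wrapper, and `wrap_model` is a hypothetical helper.

```python
class Float16Module:
    """Minimal stand-in for Megatron's mixed-precision wrapper (illustrative only)."""

    def __init__(self, config, module):
        self.config = config
        self.module = module


def wrap_model(model, mixed_precision_enabled: bool, defer_fp32_logits: bool, config=None):
    # Old behavior (removed by this PR): the wrapper was skipped entirely when
    # defer_fp32_logits=True, losing mixed-precision handling altogether.
    # New behavior: always wrap in fp16/bf16 mode; defer_fp32_logits only
    # skips the final cast of logits to fp32, which happens in the forward step.
    if mixed_precision_enabled:
        return Float16Module(config, model)
    return model


wrapped = wrap_model(model=object(), mixed_precision_enabled=True, defer_fp32_logits=True)
print(type(wrapped).__name__)  # Float16Module, even with defer_fp32_logits=True
```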
Issues
List issues that this PR closes (syntax):
Usage
- You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
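A hedged sketch of how the effective flag could be resolved, mirroring the "derived from config and current FP16/BF16 mode" behavior this PR describes. The dict-based config and the `resolve_defer_fp32_logits` helper are illustrative assumptions, not the real nemo_rl schema:

```python
def resolve_defer_fp32_logits(megatron_cfg: dict) -> bool:
    # The flag only takes effect when the model actually runs in fp16/bf16;
    # in a pure fp32 run there is no final cast to defer.
    mixed_precision = megatron_cfg.get("fp16", False) or megatron_cfg.get("bf16", False)
    return bool(megatron_cfg.get("defer_fp32_logits", False)) and mixed_precision


print(resolve_defer_fp32_logits({"bf16": True, "defer_fp32_logits": True}))  # True
print(resolve_defer_fp32_logits({"defer_fp32_logits": True}))                # False
```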
Before your PR is "Ready for review"
Pre checks:
- [ ] Make sure you read and followed Contributor guidelines
- [ ] Did you write any new necessary tests?
- [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
- [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.
Additional Information
- ...
Summary by CodeRabbit

- **New Features**
  - Added optional parameter to control FP32 logits deferral during model inference, enabling more granular control over mixed-precision optimization.
- **Improvements**
  - Enhanced compatibility with FP16/BF16 configurations through unified propagation of logits handling settings across training pipelines, including forward passes for training, logit probability generation, and sampling operations.
📝 Walkthrough
A new optional parameter defer_fp32_logits is added to the forward step function in the common module and integrated throughout the policy worker initialization and inference paths. This parameter controls whether FP32 logits output is deferred, allowing models to output logits in lower precision formats when appropriate. Docstrings are updated to document the new parameter and related padding options.
Changes
| Cohort / File(s) | Summary |
|---|---|
| **Common Forward Step Enhancement**<br>`nemo_rl/models/megatron/common.py` | Added `defer_fp32_logits: Optional[bool] = None` parameter to the `forward_step_arbitrary_loss` method signature. When truthy, sets `additional_kwargs["fp32_output"] = False` before model invocation. Updated docstrings to describe the new parameter and existing padding-related parameters. |
| **Policy Worker Integration**<br>`nemo_rl/models/policy/megatron_policy_worker.py` | Introduced `self.defer_fp32_logits` attribute in `MegatronPolicyWorker` initialization, derived from config and the current FP16/BF16 mode. Removed the unconditional disabling of the mixed-precision wrapper based on `defer_fp32_logits`. Propagated `defer_fp32_logits` to the `forward_step` calls used in training, logprob generation, and top-k logits inference paths. Applied similar changes to the reference model's mixed-precision wrapper handling. |
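Per the `common.py` change summarized above, the forward-step modification amounts to toggling one kwarg. A minimal sketch; the function name `build_forward_kwargs` is hypothetical, and only the `additional_kwargs["fp32_output"]` handling comes from this PR's summary:

```python
from typing import Optional


def build_forward_kwargs(defer_fp32_logits: Optional[bool] = None) -> dict:
    """When defer_fp32_logits is truthy, ask the model to keep logits in
    fp16/bf16 instead of casting them to fp32 at the end of the forward pass."""
    additional_kwargs: dict = {}
    if defer_fp32_logits:
        additional_kwargs["fp32_output"] = False
    return additional_kwargs


print(build_forward_kwargs(True))   # {'fp32_output': False}
print(build_forward_kwargs(None))   # {}
```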
Estimated code review effort
🎯 2 (Simple) | ⏱️ ~12 minutes
- Parameter passing is consistent and straightforward across both files
- Changes follow a clear pattern: parameter added to signature, threaded through initialization, and passed to downstream calls
- Updates are localized and do not affect core logic
- Primary focus should be verifying that parameter propagation is complete and correct through all inference paths (train, logprobs, top-k logits)
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, which is insufficient; the required threshold is 80.00%. | You can run `@coderabbitai generate docstrings` to improve docstring coverage. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title 'fix: Use Float16Module even when defer_fp32_logits=True' directly addresses the main change: ensuring Float16Module is used while deferring FP32 logits output, which is the core objective of the PR. |
| Test Results For Major Changes | ✅ Passed | Comprehensive tests for defer_fp32_logits functionality are integrated into the codebase across multiple test files and scenarios. |
✨ Finishing touches
- [ ] 📝 Generate docstrings
🧪 Generate unit tests (beta)
- [ ] Create PR with unit tests
- [ ] Post copyable unit tests in a comment
- [ ] Commit unit tests in branch `yifu/defer_fp32_logits`