
fix: Use Float16Module even when defer_fp32_logits=True

Open · yfw opened this pull request 1 month ago · 1 comment

What does this PR do?

Previously, we were skipping the Float16Module when defer_fp32_logits=True. This PR changes the logic so that we still use the Float16Module when defer_fp32_logits=True, but only skip the final cast to fp32.
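The change above can be sketched in plain Python. This is a simplified stand-in, not the actual NeMo-RL/Megatron code: the names `Float16Module` and `fp32_output` come from the PR description, while the helper functions below are hypothetical illustrations of the before/after logic.

```python
def use_float16_module(half_precision_enabled, defer_fp32_logits):
    """Decide whether to wrap the model in Float16Module.

    Before this PR (sketch): the wrapper was skipped when deferring,
    i.e. `half_precision_enabled and not defer_fp32_logits`.
    After this PR: the wrapper is used whenever fp16/bf16 is active;
    deferral no longer disables it.
    """
    return half_precision_enabled


def forward_kwargs(defer_fp32_logits):
    """Deferral now only skips the final cast of the logits to fp32."""
    additional_kwargs = {}
    if defer_fp32_logits:
        additional_kwargs["fp32_output"] = False
    return additional_kwargs


# With deferral on, the model still runs under Float16Module, but its
# logits stay in half precision instead of being cast up at the end.
print(use_float16_module(True, True))   # -> True (wrapper kept)
print(forward_kwargs(True))             # -> {'fp32_output': False}
```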

Issues

List issues that this PR closes:

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
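Since the template's snippet was left empty, here is a possible usage sketch. The config key names are assumptions for illustration and may not match the real NeMo-RL policy config.

```python
# Hypothetical policy config; the key names are illustrative assumptions.
policy_config = {
    "megatron_cfg": {"bf16": True},   # half-precision training enabled
    "defer_fp32_logits": True,        # keep logits in bf16; defer fp32 cast
}

# Deferral only makes sense when a half-precision mode is active,
# which mirrors how the worker derives its effective flag.
mcfg = policy_config["megatron_cfg"]
effective_defer = policy_config["defer_fp32_logits"] and (
    mcfg.get("fp16", False) or mcfg.get("bf16", False)
)
print(effective_defer)  # -> True
```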

Before your PR is "Ready for review"

Pre checks:

  • [ ] Make sure you read and followed Contributor guidelines
  • [ ] Did you write any new necessary tests?
  • [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Added optional parameter to control FP32 logits deferral during model inference, enabling more granular control over mixed-precision optimization.
  • Improvements

    • Enhanced compatibility with FP16/BF16 configurations through unified propagation of logits handling settings across training pipelines, including forward passes for training, logit probability generation, and sampling operations.


yfw — Nov 18 '25

πŸ“ Walkthrough


A new optional parameter defer_fp32_logits is added to the forward step function in the common module and integrated throughout the policy worker initialization and inference paths. This parameter controls whether FP32 logits output is deferred, allowing models to output logits in lower precision formats when appropriate. Docstrings are updated to document the new parameter and related padding options.

Changes

  • Common Forward Step Enhancement — nemo_rl/models/megatron/common.py
    Added a defer_fp32_logits: Optional[bool] = None parameter to the forward_step_arbitrary_loss signature. When truthy, sets additional_kwargs["fp32_output"] = False before model invocation. Docstrings updated to describe the new parameter and the existing padding-related parameters.
  • Policy Worker Integration — nemo_rl/models/policy/megatron_policy_worker.py
    Introduced a self.defer_fp32_logits attribute in MegatronPolicyWorker initialization, derived from the config and the current FP16/BF16 mode. Removed the unconditional disabling of the mixed-precision wrapper based on defer_fp32_logits. Propagated defer_fp32_logits to the forward_step calls used in training, logprob generation, and top-k logits inference. Applied the same changes to the reference model's mixed-precision wrapper handling.
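The integration described above can be approximated with a minimal stand-in. The real signatures in common.py and megatron_policy_worker.py are larger; only the defer_fp32_logits plumbing is shown here, and the surrounding class and function bodies are hypothetical.

```python
from typing import Optional


def forward_step_arbitrary_loss(model_forward, batch,
                                defer_fp32_logits: Optional[bool] = None):
    """Sketch of the new parameter: when truthy, ask the wrapped model
    to skip its final fp32 cast of the logits."""
    additional_kwargs = {}
    if defer_fp32_logits:
        additional_kwargs["fp32_output"] = False
    return model_forward(batch, **additional_kwargs)


class PolicyWorkerSketch:
    """Stand-in for MegatronPolicyWorker's defer_fp32_logits handling."""

    def __init__(self, config, fp16=False, bf16=False):
        # Derived from the config AND the current half-precision mode;
        # deferring the fp32 cast is meaningless in full precision.
        self.defer_fp32_logits = bool(
            config.get("defer_fp32_logits")) and (fp16 or bf16)

    def train_step(self, model_forward, batch):
        # The same flag is threaded through the train, logprob
        # generation, and top-k logits inference paths.
        return forward_step_arbitrary_loss(
            model_forward, batch, defer_fp32_logits=self.defer_fp32_logits)


def fake_model(batch, **kwargs):
    return kwargs  # expose what the model was asked to do


worker = PolicyWorkerSketch({"defer_fp32_logits": True}, bf16=True)
print(worker.train_step(fake_model, batch=[]))  # -> {'fp32_output': False}
```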

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

  • Parameter passing is consistent and straightforward across both files
  • Changes follow a clear pattern: parameter added to signature, threaded through initialization, and passed to downstream calls
  • Updates are localized and do not affect core logic
  • Primary focus should be verifying that parameter propagation is complete and correct through all inference paths (train, logprob generation, top-k logits)

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (3 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title 'fix: Use Float16Module even when defer_fp32_logits=True' directly addresses the main change: ensuring Float16Module is used while deferring FP32 logits output, which is the core objective of the PR.
  • Test Results For Major Changes — ✅ Passed: comprehensive tests for the defer_fp32_logits functionality are integrated into the codebase across multiple test files and scenarios.




coderabbitai[bot] — Nov 25 '25