fix: Megatron static inference and adapt to mcore engine API changes

Open shanmugamr1992 opened this issue 1 month ago • 3 comments

What does this PR do?

Fixes several bugs in Megatron static inference.

The bugs and the corresponding changes:

  1. When the backend is Megatron, policy_generation is set to None, so that case had to be handled.
  2. The mcore update changed the run_mcore_engine API. The new API accepts only text prompts.
  3. The code now uses the static engine directly instead of going through run_mcore_engine.
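
The first item can be sketched as a simple None guard, in the spirit of the check the walkthrough below describes for grpo.py. This is a minimal illustration and not the code in the PR: `FakeGeneration` and `maybe_prepare_refit` are hypothetical names, and only `policy_generation` and `prepare_refit_info` come from this page.

```python
from typing import Optional


class FakeGeneration:
    """Hypothetical stand-in for a separate generation backend (not NeMo RL code)."""

    def __init__(self) -> None:
        self.refit_info: Optional[dict] = None

    def prepare_refit_info(self, state_dict_info: dict) -> None:
        # Record the info that a later weight refit would need.
        self.refit_info = state_dict_info


def maybe_prepare_refit(
    policy_generation: Optional[FakeGeneration], state_dict_info: dict
) -> bool:
    """Call prepare_refit_info only when a generation object exists."""
    if policy_generation is None:
        # Per the PR description, policy_generation is None with the
        # Megatron backend, so the call is skipped instead of raising.
        return False
    policy_generation.prepare_refit_info(state_dict_info)
    return True
```

Without the guard, calling `prepare_refit_info` on None would raise an AttributeError.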

Added functional and nightly tests.

Issues

  File "/opt/nemo-rl/nemo_rl/models/policy/megatron_policy_worker.py", line 1839, in generate
    result = inference_engine.generate(inference_requests=requests)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/static_engine.py", line 192, in generate
    self.run_engine()
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/static_engine.py", line 226, in run_engine
    self.controller.generate_all_output_tokens_static_batch(
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/text_generation_controllers/text_generation_controller.py", line 841, in generate_all_output_tokens_static_batch
    logits = self.inference_wrapped_model.run_one_forward_step(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 389, in run_one_forward_step
    return self.forward_pass_without_pipeline_parallel(inference_input)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 213, in forward_pass_without_pipeline_parallel
    logits = self._forward(inference_input)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 161, in _forward
    return self.model(
           ^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/module.py", line 429, in forward
    outputs = self.module(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 441, in forward
    preproc_output = self._preprocess(
                     ^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 300, in _preprocess
    decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/common/embeddings/language_model_embedding.py", line 111, in forward
    word_embeddings = self.word_embeddings(input_ids)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1816, in inner
    args_result = hook(self, args)
                  ^^^^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 435, in hook
    self.param_to_bucket_group[param].finish_param_sync(
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 286, in finish_param_sync
    self.start_param_sync()
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 242, in start_param_sync
    self.cached_param_buffer_shard_list[idx] = shard_buffer(
                                               ^^^^^^^^^^^^^
  File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 56, in shard_buffer
    buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
    ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: setStorage: sizes [238883840], strides [1], storage offset 1304830464, and itemsize 2 requiring a storage size of 3087428608 are out of bounds for storage of size 0
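
The numbers in the RuntimeError are internally consistent, which can be checked with a little arithmetic. For a contiguous 1-D view (stride 1), the storage must hold at least (storage offset + number of elements) × itemsize bytes. The sketch below only re-derives the figure from the message; `required_storage_bytes` is a hypothetical helper and says nothing about why the storage was empty.

```python
def required_storage_bytes(numel: int, storage_offset: int, itemsize: int) -> int:
    """Bytes needed to back a stride-1 view of `numel` elements at `storage_offset`."""
    return (storage_offset + numel) * itemsize


# Values copied from the error message above.
required = required_storage_bytes(
    numel=238_883_840, storage_offset=1_304_830_464, itemsize=2
)
print(required)  # 3087428608, matching "requiring a storage size of 3087428608"
```

So the requested slice is well-formed for a populated buffer; the failure comes from the underlying storage having size 0 at the moment shard_buffer slices it.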

Before your PR is "Ready for review"

Pre checks:

  • [ ] Make sure you read and followed Contributor guidelines
  • [ ] Did you write any new necessary tests?
  • [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Megatron-based generation support with batch inference optimization.
  • Refactor

    • Improved generation engine pipeline for enhanced performance and stability.
  • Tests

    • Added functional test coverage for Megatron generation workflows.

shanmugamr1992 avatar Nov 07 '25 22:11 shanmugamr1992

Walkthrough

Changes introduce batch generation support via a new inference engine flow in MegatronPolicyWorker, add protective None checks in the GRPO algorithm, and establish two new functional test scripts for Megatron-based GRPO generation with validation metrics.

Changes

| Cohort / File(s) | Change Summary |
|---|---|
| Algorithm safety check: nemo_rl/algorithms/grpo.py | Adds None guard to policy_generation.prepare_refit_info() call, preventing execution when policy_generation is None. |
| Generation engine refactor: nemo_rl/models/policy/megatron_policy_worker.py | Rewrites MegatronPolicyWorker.generate to use batch inference flow: computes tokens_to_generate, pads prompts with EOS tokens, constructs SamplingParams and InferenceRequest objects, invokes inference_engine.generate(), and consolidates results into BatchedDataDict. Adds conditional CUDA movement and new imports (SamplingParams, InferenceRequest). |
| Test infrastructure: tests/functional/grpo_megatron_generation.sh, tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh, tests/test_suites/nightly.txt | Adds two new functional test scripts for GRPO with Megatron backend (0.5B and 1B model variants) with metric validation, and registers new test in nightly suite. |

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant MegatronPolicyWorker
    participant InferenceEngine
    participant SamplingParams as SamplingParams
    participant Output as BatchedDataDict
    
    rect rgb(200, 220, 255)
    Note over User,Output: New Batch Generation Flow
    User->>MegatronPolicyWorker: generate(prompts, max_new_tokens, ...)
    
    opt Move to CUDA if needed
        MegatronPolicyWorker->>MegatronPolicyWorker: Move model to CUDA
    end
    
    MegatronPolicyWorker->>MegatronPolicyWorker: Pad prompts to length<br/>Create prompt tensors
    MegatronPolicyWorker->>SamplingParams: Build with temperature=1.0,<br/>top_k=0, return_log_probs=True
    MegatronPolicyWorker->>MegatronPolicyWorker: Create InferenceRequest<br/>batch for each prompt
    MegatronPolicyWorker->>InferenceEngine: generate(inference_requests)
    InferenceEngine-->>MegatronPolicyWorker: result objects
    MegatronPolicyWorker->>Output: Consolidate to BatchedDataDict<br/>(text, tokens, logprobs)
    Output-->>User: Return generation output
    end
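
The padding and request-building steps in the diagram can be sketched in plain Python. This is an illustration under stated assumptions, not the PR's implementation: `SketchRequest` and `build_requests` are hypothetical names (this is not Megatron-LM's InferenceRequest), and right-padding to the longest prompt is assumed, since the walkthrough only says prompts are padded with EOS tokens.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SketchRequest:
    """Hypothetical request record; not Megatron-LM's InferenceRequest."""

    request_id: int
    prompt_tokens: List[int]  # token ids, padded to a common length
    prompt_length: int        # number of real (unpadded) tokens


def build_requests(prompts: List[List[int]], eos_id: int) -> List[SketchRequest]:
    """Right-pad each prompt with eos_id to the longest prompt; one request per prompt."""
    if not prompts:
        return []
    max_len = max(len(p) for p in prompts)
    requests = []
    for i, tokens in enumerate(prompts):
        padded = tokens + [eos_id] * (max_len - len(tokens))
        requests.append(SketchRequest(i, padded, len(tokens)))
    return requests
```

Keeping the unpadded length alongside the padded tokens lets later steps tell real prompt tokens from padding when results are consolidated.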

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

  • megatron_policy_worker.py requires careful review of the new inference engine integration, tensor construction, and output consolidation logic
  • grpo.py None check is straightforward but context matters; verify policy_generation lifecycle
  • Test scripts validate integration but are primarily configuration-driven; focus on metric validation logic in check_metrics.py expectations

Suggested labels

CI:L1, r0.4.0

Suggested reviewers

  • terrykong

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Test Results For Major Changes | ⚠️ Warning | PR makes major inference pipeline changes but PR description lacks test results, metrics, or regression confirmation. | Add test results, before-and-after metrics, and regression confirmation to PR description to validate the significant inference engine changes. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically describes the main changes: fixing static inference and adapting to mcore engine API changes, which aligns with the PR objectives and file modifications. |

coderabbitai[bot] avatar Nov 07 '25 22:11 coderabbitai[bot]

@coderabbit title

shanmugamr1992 avatar Nov 07 '25 23:11 shanmugamr1992

@shanmugamr1992, based on the changes in this PR, here's a more descriptive title suggestion:

"Fix Megatron static inference and adapt to mcore engine API changes"

Or alternatively:

"Fix static inference bugs in Megatron backend and add tests"

These titles better capture:

  • The specific area being fixed (Megatron static inference)
  • The core changes (bug fixes and API adaptation)
  • The addition of tests

Would you like me to update the PR title to one of these, or would you prefer a different phrasing?

coderabbitai[bot] avatar Nov 07 '25 23:11 coderabbitai[bot]