fix: Megatron static inference and adapt to mcore engine API changes
What does this PR do?
Fixes several bugs in static inference with the Megatron backend:
- When the backend is Megatron, `policy_generation` is set to `None`, so the `prepare_refit_info()` call in `grpo.py` now checks for `None` before executing (see the sketch after this list).
- A recent mcore update changed the `run_mcore_engine` API; the new API accepts only text prompts.
- Changed the code to use the static engine directly instead of going through `run_mcore_engine`.
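A minimal sketch of the guard, assuming a hypothetical argument name (the actual call site is in `nemo_rl/algorithms/grpo.py`):

```python
# Sketch of the None guard described above. The argument name
# state_dict_info is hypothetical; only the guard shape mirrors the fix.
if policy_generation is not None:
    # With the Megatron backend, policy_generation is None, so refit-info
    # preparation must be skipped rather than called unconditionally.
    policy_generation.prepare_refit_info(state_dict_info)
```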
Added tests (functional and nightly).
Issues

Static inference previously failed with the traceback below; the error surfaces inside Megatron's DDP parameter-sync hook during the forward pass.

```
File "/opt/nemo-rl/nemo_rl/models/policy/megatron_policy_worker.py", line 1839, in generate
result = inference_engine.generate(inference_requests=requests)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/static_engine.py", line 192, in generate
self.run_engine()
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/static_engine.py", line 226, in run_engine
self.controller.generate_all_output_tokens_static_batch(
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/text_generation_controllers/text_generation_controller.py", line 841, in generate_all_output_tokens_static_batch
logits = self.inference_wrapped_model.run_one_forward_step(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 389, in run_one_forward_step
return self.forward_pass_without_pipeline_parallel(inference_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 213, in forward_pass_without_pipeline_parallel
logits = self._forward(inference_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py", line 161, in _forward
return self.model(
^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
return self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/transformer/module.py", line 429, in forward
outputs = self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 441, in forward
preproc_output = self._preprocess(
^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 300, in _preprocess
decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/models/common/embeddings/language_model_embedding.py", line 111, in forward
word_embeddings = self.word_embeddings(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
^^^^^^^
File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1816, in inner
args_result = hook(self, args)
^^^^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 435, in hook
self.param_to_bucket_group[param].finish_param_sync(
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 286, in finish_param_sync
self.start_param_sync()
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 242, in start_param_sync
self.cached_param_buffer_shard_list[idx] = shard_buffer(
^^^^^^^^^^^^^
File "/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 56, in shard_buffer
buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: setStorage: sizes [238883840], strides [1], storage offset 1304830464, and itemsize 2 requiring a storage size of 3087428608 are out of bounds for storage of size 0
```
Before your PR is "Ready for review"
Pre checks:
- [ ] Make sure you read and followed Contributor guidelines
- [ ] Did you write any new necessary tests?
- [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests.
- [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.
Additional Information
- ...
Summary by CodeRabbit

Release Notes

- New Features
  - Added Megatron-based generation support with batch inference optimization.
- Refactor
  - Improved generation engine pipeline for enhanced performance and stability.
- Tests
  - Added functional test coverage for Megatron generation workflows.
📝 Walkthrough
Changes introduce batch generation support via a new inference engine flow in MegatronPolicyWorker, add protective None checks in the GRPO algorithm, and establish two new functional test scripts for Megatron-based GRPO generation with validation metrics.
Changes
| Cohort / File(s) | Change Summary |
|---|---|
| Algorithm safety check: `nemo_rl/algorithms/grpo.py` | Adds a `None` guard to the `policy_generation.prepare_refit_info()` call, preventing execution when `policy_generation` is `None`. |
| Generation engine refactor: `nemo_rl/models/policy/megatron_policy_worker.py` | Rewrites `MegatronPolicyWorker.generate` to use a batch inference flow: computes `tokens_to_generate`, pads prompts with EOS tokens, constructs `SamplingParams` and `InferenceRequest` objects, invokes `inference_engine.generate()`, and consolidates results into a `BatchedDataDict` (see the sketch below the table). Adds conditional CUDA movement and new imports (`SamplingParams`, `InferenceRequest`). |
| Test infrastructure: `tests/functional/grpo_megatron_generation.sh`, `tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.sh`, `tests/test_suites/nightly.txt` | Adds two new functional test scripts for GRPO with the Megatron backend (0.5B and 1B model variants) with metric validation, and registers the new test in the nightly suite. |
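As a rough illustration of the new flow, here is a minimal sketch assuming recent Megatron-Core dataclasses; the helper name `generate_batch` and its parameters are hypothetical, and only `SamplingParams`, `InferenceRequest`, and `inference_engine.generate()` come from the change summary above:

```python
# Minimal sketch of the batch-generation flow described above, assuming
# recent Megatron-Core dataclasses. Field names (e.g. sampling_params vs.
# the older inference_parameters) and the set of required fields vary
# across mcore versions; treat this as illustrative, not the exact PR code.
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams


def generate_batch(inference_engine, tokenizer, prompt_token_lists, max_total_len):
    requests = []
    for i, prompt_tokens in enumerate(prompt_token_lists):
        # Static batching: each request gets the remaining token budget.
        sampling_params = SamplingParams(
            temperature=1.0,
            top_k=0,
            return_log_probs=True,
            num_tokens_to_generate=max_total_len - len(prompt_tokens),
        )
        requests.append(
            InferenceRequest(
                request_id=str(i),
                prompt=tokenizer.decode(prompt_tokens),
                prompt_tokens=prompt_tokens,
                sampling_params=sampling_params,
            )
        )
    # The static engine runs every request to completion in one call and
    # returns per-request results (generated text, tokens, log-probs).
    return inference_engine.generate(inference_requests=requests)
```

In the actual worker, the per-request results are then consolidated into a `BatchedDataDict` of text, tokens, and log-probs, as the walkthrough notes.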
Sequence Diagram(s)
```mermaid
sequenceDiagram
participant User
participant MegatronPolicyWorker
participant InferenceEngine
participant SamplingParams as SamplingParams
participant Output as BatchedDataDict
rect rgb(200, 220, 255)
Note over User,Output: New Batch Generation Flow
User->>MegatronPolicyWorker: generate(prompts, max_new_tokens, ...)
opt Move to CUDA if needed
MegatronPolicyWorker->>MegatronPolicyWorker: Move model to CUDA
end
MegatronPolicyWorker->>MegatronPolicyWorker: Pad prompts to length<br/>Create prompt tensors
MegatronPolicyWorker->>SamplingParams: Build with temperature=1.0,<br/>top_k=0, return_log_probs=True
MegatronPolicyWorker->>MegatronPolicyWorker: Create InferenceRequest<br/>batch for each prompt
MegatronPolicyWorker->>InferenceEngine: generate(inference_requests)
InferenceEngine-->>MegatronPolicyWorker: result objects
MegatronPolicyWorker->>Output: Consolidate to BatchedDataDict<br/>(text, tokens, logprobs)
Output-->>User: Return generation output
end
```
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~30 minutes
- megatron_policy_worker.py requires careful review of the new inference engine integration, tensor construction, and output consolidation logic
- grpo.py None check is straightforward but context matters; verify policy_generation lifecycle
- Test scripts validate integration but are primarily configuration-driven; focus on the expected metrics checked via check_metrics.py
Suggested labels
CI:L1, r0.4.0
Suggested reviewers
- terrykong
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Test Results For Major Changes | ⚠️ Warning | PR makes major inference pipeline changes but PR description lacks test results, metrics, or regression confirmation. | Add test results, before-and-after metrics, and regression confirmation to PR description to validate the significant inference engine changes. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically describes the main changes: fixing static inference and adapting to mcore engine API changes, which aligns with the PR objectives and file modifications. |
✨ Finishing touches
- [ ] 📝 Generate docstrings
🧪 Generate unit tests (beta)
- [ ] Create PR with unit tests
- [ ] Post copyable unit tests in a comment
- [ ] Commit unit tests in branch `fix_static`
@coderabbit title
@shanmugamr1992, based on the changes in this PR, here's a more descriptive title suggestion:
"Fix Megatron static inference and adapt to mcore engine API changes"
Or alternatively:
"Fix static inference bugs in Megatron backend and add tests"
These titles better capture:
- The specific area being fixed (Megatron static inference)
- The core changes (bug fixes and API adaptation)
- The addition of tests
Would you like me to update the PR title to one of these, or would you prefer a different phrasing?