fix: fix DTensor sharding error when bumping up the PyTorch version
What does this PR do?
Fixes a DTensor sharding error that appeared with sequence parallelism after the PyTorch version bump, by adding a `SequenceParallelAllGatherActivation` parallel style and removing the restriction on combining sequence parallelism with tensor parallelism.
Successful run after the fix, with TP=2 and sequence parallelism enabled on a Qwen model: https://wandb.ai/nvidia/grpo-dev-zhiyul/runs/nyq6n98w/overview?nw=nwuserzhiyul
Issues
List issues that this PR closes (syntax):
Usage
- You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
- [ ] Make sure you read and followed Contributor guidelines
- [ ] Did you write any new necessary tests?
- [ ] Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
- [ ] Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.
Additional Information
- ...
Summary by CodeRabbit

**New Features**
- Enabled sequence parallelism to work in combination with tensor parallelism for improved distributed training performance.
- Enhanced parallelization strategies for Llama, Qwen, and Gemma3 models with optimized activation handling in parallel configurations.

**Bug Fixes**
- Removed the runtime error that previously blocked sequence parallelism when using tensor parallelism, enabling more flexible model parallelization options (a minimal sketch of the relaxed check follows).
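For context, a minimal sketch of the relaxed check is shown below; the variable names are hypothetical and the actual logic lives in `nemo_rl/models/policy/dtensor_policy_worker.py`.

```python
import warnings


def validate_sequence_parallel(sequence_parallel: bool, tp_size: int) -> None:
    """Sketch: SP combined with TP is now allowed; only the TP=1 warning remains."""
    if sequence_parallel and tp_size == 1:
        # Sequence parallelism has no effect without tensor parallelism.
        warnings.warn(
            "sequence_parallel=True has no effect when tensor_parallel_size == 1"
        )
    # The previous RuntimeError raised for sequence_parallel with tp_size > 1 is
    # removed, so configurations such as TP=2 with SP enabled can proceed.
```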
⚠️ File Consistency Check
Check based on commit: 8c58b54fa54296ea9edd104c24afcad86dc7fdc1 (PR #1557 from zhiyul/fix_sharding_qwen)
⚠️ Parallel Plans Synchronization Warning
The file nemo_rl/models/dtensor/parallelize.py was modified in this PR, but neither 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py nor 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py was updated.
Why this matters: These files contain similar parallel plan implementations that should be kept synchronized to ensure consistency across the codebase.
Action required:
- Please review if the changes in `nemo_rl/models/dtensor/parallelize.py` should also be applied to `3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py` or `3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py`
- Update the appropriate related file(s) if necessary to maintain functional consistency
- Request access to the NVIDIA-NeMo/Automodel repository, create a PR against the `nemo-rl-submodule` branch, and update the Automodel submodule in the nemo-rl index
- Add @ffrujeri as a reviewer of this PR if you have any questions about the consistency requirements
- If the files are intentionally different, please add a comment in the PR explaining why
Files to check:
- Modified: `nemo_rl/models/dtensor/parallelize.py`
- Not modified: `3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py`
- Not modified: `3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py`
⚠️ DTensor Policy Worker Synchronization Warning
The file nemo_rl/models/policy/dtensor_policy_worker.py was modified in this PR, but nemo_rl/models/policy/dtensor_policy_worker_v2.py was not updated.
Why this matters: These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.
Action required:
- Please review if the changes in `nemo_rl/models/policy/dtensor_policy_worker.py` should also be applied to `nemo_rl/models/policy/dtensor_policy_worker_v2.py`
- Update `nemo_rl/models/policy/dtensor_policy_worker_v2.py` if necessary to maintain consistency
- If the files are intentionally different, please add a comment in the PR explaining why
Files to check:
- Modified: `nemo_rl/models/policy/dtensor_policy_worker.py`
- Not modified: `nemo_rl/models/policy/dtensor_policy_worker_v2.py`
This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
📝 Walkthrough
Walkthrough
Introduces SequenceParallelAllGatherActivation class to redistribute sequence-parallel DTensor activations from Shard to Replicate placements across normalization and attention layers in multiple model architectures. Also removes the runtime restriction preventing sequence parallelism when tensor parallelism is enabled.
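As a rough illustration of the behavior described above, the sketch below shows one way such a parallel style can be written on top of `torch.distributed.tensor.parallel.SequenceParallel`. The override signature and redistribution details are assumptions inferred from this summary, not the merged implementation.

```python
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import SequenceParallel


class SequenceParallelAllGatherActivation(SequenceParallel):
    """SequenceParallel variant that all-gathers (replicates) sharded activations."""

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        # If the activation is still sharded (e.g. along the sequence dimension),
        # redistribute it to Replicate so downstream TP layers see the full sequence.
        if isinstance(outputs, DTensor) and any(
            isinstance(p, Shard) for p in outputs.placements
        ):
            outputs = outputs.redistribute(
                device_mesh, [Replicate()] * device_mesh.ndim
            )
        # Delegate the remaining handling (e.g. use_local_output) to the parent class.
        return SequenceParallel._prepare_output_fn(
            use_local_output, mod, outputs, device_mesh
        )
```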
Changes
| Cohort / File(s) | Change Summary |
|---|---|
| **SequenceParallelAllGatherActivation implementation**<br/>`nemo_rl/models/dtensor/parallelize.py` | Adds a new `SequenceParallelAllGatherActivation` class extending `SequenceParallel` with a `_prepare_output_fn` static method that redistributes DTensors with `Shard` placements to `Replicate` before delegating to the parent class. |
| **Model parallelization strategy updates**<br/>`nemo_rl/models/dtensor/parallelize.py` | Updates model-parallel mappings for the Llama, Qwen, and Gemma3 architectures to use `SequenceParallelAllGatherActivation` for the `input_layernorm` and `post_attention_layernorm` layers; adds `gate_up_proj` projections to the parallelization plans; configures `use_local_output=False` where applicable (a hypothetical excerpt follows the table). |
| **Tensor parallelism restriction removal**<br/>`nemo_rl/models/policy/dtensor_policy_worker.py` | Removes the runtime error that blocked sequence parallelism when `tp_size > 1`; retains the warning for the `tp_size == 1` case. |
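To make the plan-update row concrete, here is a hypothetical excerpt of such a per-layer mapping. The wildcard module paths follow the usual Hugging Face Llama/Qwen naming, and the constructor arguments and exact layer list are assumptions; the authoritative plans live in `nemo_rl/models/dtensor/parallelize.py`.

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# The class added by this PR in nemo_rl/models/dtensor/parallelize.py.
from nemo_rl.models.dtensor.parallelize import SequenceParallelAllGatherActivation


def example_decoder_layer_plan() -> dict:
    """Illustrative per-layer TP/SP plan; the real plans in parallelize.py may differ."""
    return {
        # Norm layers run sequence-parallel and all-gather their output so the
        # projections below receive the full (replicated) sequence.
        "model.layers.*.input_layernorm": SequenceParallelAllGatherActivation(),
        "model.layers.*.post_attention_layernorm": SequenceParallelAllGatherActivation(),
        # Attention / MLP projections keep the usual column/row-wise sharding;
        # gate_up_proj is the newly added fused projection entry.
        "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False),
        "model.layers.*.self_attn.o_proj": RowwiseParallel(),
        "model.layers.*.mlp.gate_up_proj": ColwiseParallel(),
        "model.layers.*.mlp.down_proj": RowwiseParallel(),
    }
```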
Sequence Diagram(s)
sequenceDiagram
participant Input as Input Activation<br/>(Shard placement)
participant SPAGA as SequenceParallelAllGatherActivation
participant Parent as SequenceParallel._prepare_output_fn
participant Output as Output Activation<br/>(Replicate placement)
Input->>SPAGA: outputs (DTensor with Shard)
activate SPAGA
SPAGA->>SPAGA: Check for Shard placement
alt Has Shard placement
SPAGA->>SPAGA: Redistribute to Replicate
end
SPAGA->>Parent: Call parent _prepare_output_fn
activate Parent
Parent->>Output: Apply parent logic
deactivate Parent
SPAGA->>Output: Return redistributed output
deactivate SPAGA
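For intuition, the self-contained snippet below (illustrative only, not code from this PR) reproduces the Shard-to-Replicate step from the diagram on a 1-D CPU device mesh; it assumes a `torchrun` launch with the gloo backend.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor


def demo_allgather_activation():
    # Assumes launch via `torchrun --nproc-per-node=<tp_size> this_script.py`.
    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    # A full activation [batch, seq, hidden]; under SP it is sharded on dim 1.
    activation = torch.randn(2, 8, 16)
    sharded = distribute_tensor(activation, mesh, [Shard(1)])

    # The "all-gather" step from the diagram: Shard(1) -> Replicate().
    replicated = sharded.redistribute(mesh, [Replicate()])
    assert replicated.placements == (Replicate(),)
    dist.destroy_process_group()
```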
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes
- **SequenceParallelAllGatherActivation class logic:** Verify the DTensor redistribution from Shard to Replicate is correct and compatible with the parent-class delegation
- **Model architecture consistency:** Ensure the Llama, Qwen, and Gemma3 updates follow the same pattern for normalization and projection layers
- **gate_up_proj addition:** Confirm the new `model.layers.*.mlp.gate_up_proj` (`ColwiseParallel`) entry is correctly integrated
- **tp_size restriction removal:** Validate that removing the `tp_size > 1` guard doesn't violate other assumptions in state loading or model initialization
Suggested labels
r0.4.0
Suggested reviewers
- yaoyu-33
- terrykong
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Test Results For Major Changes | ⚠️ Warning | PR introduces major distributed training changes without test results, performance benchmarks, convergence validation, or CI/CD documentation. | Add comprehensive testing information including numerical correctness validation, performance benchmarks, distributed training verification, and CI/CD test run references. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title describes fixing a DTensor sharding error related to PyTorch version bumping, which is partially related to the changes. However, the actual changes involve introducing SequenceParallelAllGatherActivation and removing sequence-parallel restrictions with tensor parallelism, which is broader than just fixing a sharding error. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |