fix: fix Dtensor sharding error when bump up pytorch version #1557
base: main
Conversation
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
📝 Walkthrough
Introduces a SequenceParallelAllGatherActivation parallel style in nemo_rl/models/dtensor/parallelize.py and applies it to the normalization layers in the Llama and Qwen tensor-parallel plans.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Input as Input Activation<br/>(Shard placement)
    participant SPAGA as SequenceParallelAllGatherActivation
    participant Parent as SequenceParallel._prepare_output_fn
    participant Output as Output Activation<br/>(Replicate placement)
    Input->>SPAGA: outputs (DTensor with Shard)
    activate SPAGA
    SPAGA->>SPAGA: Check for Shard placement
    alt Has Shard placement
        SPAGA->>SPAGA: Redistribute to Replicate
    end
    SPAGA->>Parent: Call parent _prepare_output_fn
    activate Parent
    Parent->>Output: Apply parent logic
    deactivate Parent
    SPAGA->>Output: Return redistributed output
    deactivate SPAGA
```
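In code, the flow in the diagram corresponds roughly to the sketch below. This is an illustrative reconstruction from the diagram and the review comments (using the reviewer's suggested TypeError), not the PR's exact implementation, and it assumes SequenceParallel._prepare_output_fn keeps its (use_local_output, mod, outputs, device_mesh) signature from torch.distributed.tensor.parallel.

```python
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import SequenceParallel


class SequenceParallelAllGatherActivation(SequenceParallel):
    """Sketch: all-gather sharded activations back to Replicate before handing off."""

    @staticmethod
    def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
        if isinstance(outputs, DTensor):
            if any(isinstance(p, Shard) for p in outputs.placements):
                # Redistributing Shard(...) -> Replicate() on the TP mesh is an all-gather.
                outputs = outputs.redistribute(device_mesh, [Replicate()])
        else:
            raise TypeError(f"Expected output to be a DTensor, but got {type(outputs)}")
        # Delegate to the parent so use_local_output is honored as usual.
        return SequenceParallel._prepare_output_fn(use_local_output, mod, outputs, device_mesh)
```

The Shard-to-Replicate redistribute is implemented as an all-gather, hence the class name.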
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✨ Finishing touches
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/dtensor/parallelize.py (1)
156-166: Update Gemma3 parallelization to match Llama and Qwen fixes.
The `_parallelize_gemma3` function uses `SequenceParallel()` for normalization layers (lines 156, 160, 161, 165), while `_parallelize_llama` and `_parallelize_qwen` have been updated to use `SequenceParallelAllGatherActivation()` for the same layers. Since Gemma3 has identical layer structure and is architecturally related to Qwen2 (see line 91 comment), it likely encounters the same PyTorch 2.8.0 DTensor sharding issue and should apply the same fix:
- Line 156: `input_layernorm` should use `SequenceParallelAllGatherActivation(use_local_output=False)`
- Line 160: `post_attention_layernorm` should use `SequenceParallelAllGatherActivation(use_local_output=False)`
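Concretely, the suggested change would presumably amount to plan entries like the following (a hypothetical sketch mirroring the Llama plan, not the repo's actual code):

```python
from nemo_rl.models.dtensor.parallelize import SequenceParallelAllGatherActivation

# Hypothetical Gemma3 plan entries mirroring the Llama/Qwen fix:
gemma3_norm_plan_update = {
    "model.layers.*.input_layernorm": SequenceParallelAllGatherActivation(use_local_output=False),
    "model.layers.*.post_attention_layernorm": SequenceParallelAllGatherActivation(use_local_output=False),
}
```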
🧹 Nitpick comments (1)
nemo_rl/models/dtensor/parallelize.py (1)
72-87: Good addition, but consider using `TypeError` for type validation.
The `SequenceParallelAllGatherActivation` class correctly implements all-gather behavior for sharded DTensor activations. However, consider the following improvements:
- Use `TypeError` instead of `ValueError`: when validating types, `TypeError` is more semantically appropriate.
- Handle edge cases: consider what happens if `outputs` is a tuple/list of DTensors or other composite structures.

Apply this diff to use `TypeError`:

```diff
-        else:
-            raise ValueError(f"Expected output to be a DTensor, but got {type(outputs)}")
+        else:
+            raise TypeError(f"Expected output to be a DTensor, but got {type(outputs)}")
```

Based on static analysis hints.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- nemo_rl/models/dtensor/parallelize.py (3 hunks)
- nemo_rl/models/policy/dtensor_policy_worker.py (0 hunks)
💤 Files with no reviewable changes (1)
- nemo_rl/models/policy/dtensor_policy_worker.py
🧰 Additional context used
🪛 Ruff (0.14.5)
nemo_rl/models/dtensor/parallelize.py
84-84: Prefer TypeError exception for invalid type
(TRY004)
84-84: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/dtensor/parallelize.py (2)
264-264: Good addition of fused gate_up_proj support.
Adding support for the fused `gate_up_proj` projection is a sensible optimization that some Qwen model variants use instead of separate `gate_proj` and `up_proj` layers.
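For context, a fused gate/up projection is typically sharded column-wise just like the separate projections. A hypothetical plan fragment (module paths assumed, not verified against the repo) might look like:

```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

# Hypothetical Qwen MLP entries; the fused gate_up_proj replaces separate gate_proj/up_proj.
qwen_mlp_plan = {
    "model.layers.*.mlp.gate_up_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
```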
201-203: Approved: Correct application of all-gather for normalization layers.
The use of `SequenceParallelAllGatherActivation(use_local_output=False)` at lines 201 and 203 ensures proper activation gathering before attention and MLP operations. The pattern is consistent within the Llama configuration.
| "model.layers.*.input_layernorm": SequenceParallelAllGatherActivation(), | ||
| "model.layers.*.self_attn.q_proj": ColwiseParallel(), | ||
| "model.layers.*.self_attn.k_proj": ColwiseParallel(), | ||
| "model.layers.*.self_attn.v_proj": ColwiseParallel(), | ||
| "model.layers.*.self_attn.qkv_proj": ColwiseParallel(), | ||
| "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), | ||
| "model.layers.*.self_attn.q_norm": Qwen3QKNorm(), | ||
| "model.layers.*.self_attn.k_norm": Qwen3QKNorm(), | ||
| "model.layers.*.post_attention_layernorm": SequenceParallel(), | ||
| "model.layers.*.post_attention_layernorm": SequenceParallelAllGatherActivation(), |
🧩 Analysis chain
Verify the omission of use_local_output parameter.
In _parallelize_llama, the SequenceParallelAllGatherActivation is instantiated with use_local_output=False (lines 201, 203), but here in _parallelize_qwen, this parameter is omitted for both input_layernorm (line 253) and post_attention_layernorm (line 261). This inconsistency could lead to different behavior between model architectures.
Please verify whether this difference is intentional or if use_local_output=False should be added for consistency:
🏁 Script executed:
#!/bin/bash
# Check the default value of use_local_output in SequenceParallel
ast-grep --pattern $'class SequenceParallel($$$) {
$$$
def __init__(self, $$$use_local_output$$$) {
$$$
}
$$$
}'
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Search for SequenceParallelAllGatherActivation class definition
rg -n "class SequenceParallelAllGatherActivation" --type pythonLength of output: 30
🏁 Script executed:
#!/bin/bash
# Search for SequenceParallelAllGatherActivation usages including __init__ or def
rg -n "SequenceParallelAllGatherActivation" --type python -C 3Length of output: 30
🏁 Script executed:
#!/bin/bash
# Look at the actual code around lines 253-261 in parallelize.py
sed -n '250,265p' nemo_rl/models/dtensor/parallelize.py
Length of output: 1035
🏁 Script executed:
#!/bin/bash
# Look at the actual code around lines 201-203 in parallelize.py
sed -n '198,210p' nemo_rl/models/dtensor/parallelize.py
Length of output: 662
🏁 Script executed:
#!/bin/bash
# Search for SequenceParallelAllGatherActivation class definition without type flag
rg -n "class SequenceParallelAllGatherActivation"Length of output: 101
🏁 Script executed:
#!/bin/bash
# Check imports at the top of parallelize.py
head -50 nemo_rl/models/dtensor/parallelize.py | grep -n "SequenceParallel"
Length of output: 24
🏁 Script executed:
#!/bin/bash
# Find where SequenceParallelAllGatherActivation is defined or imported
rg -n "SequenceParallelAllGatherActivation" -A 5 | head -100Length of output: 2759
🏁 Script executed:
#!/bin/bash
# Get the full definition of SequenceParallelAllGatherActivation class
sed -n '72,120p' nemo_rl/models/dtensor/parallelize.py
Length of output: 2246
🏁 Script executed:
#!/bin/bash
# Search for the parent SequenceParallel class or its import source
rg -n "class SequenceParallel" --type-add py:*.py -t pyLength of output: 101
🏁 Script executed:
#!/bin/bash
# Check what SequenceParallel is imported from
head -40 nemo_rl/models/dtensor/parallelize.py
Length of output: 1399
🌐 Web query:
PyTorch SequenceParallel use_local_output parameter default value
💡 Result:
The default value is False — SequenceParallel(..., use_local_output=False). [1][2]
Sources:
[1] PyTorch docs — torch.distributed.tensor.parallel (SequenceParallel signature).
[2] PyTorch tutorial / docs examples showing use_local_output=False for SequenceParallel.
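Given that default, the two spellings below are equivalent for the base class, which is why the omission is a consistency question rather than a behavior change:

```python
from torch.distributed.tensor.parallel import SequenceParallel

# Equivalent, since SequenceParallel defaults to use_local_output=False (per the docs cited above).
sp_default = SequenceParallel()
sp_explicit = SequenceParallel(use_local_output=False)
```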
🏁 Script executed:
#!/bin/bash
# Find all _parallelize functions and check their SequenceParallelAllGatherActivation usage
rg -n "def _parallelize" nemo_rl/models/dtensor/parallelize.pyLength of output: 139
🏁 Script executed:
#!/bin/bash
# Check all SequenceParallelAllGatherActivation usages in the file
rg -n "SequenceParallelAllGatherActivation" nemo_rl/models/dtensor/parallelize.pyLength of output: 473
🏁 Script executed:
#!/bin/bash
# Check _parallelize_gemma3 for SequenceParallelAllGatherActivation usage
sed -n '128,178p' nemo_rl/models/dtensor/parallelize.py
Length of output: 2426
🏁 Script executed:
#!/bin/bash
# Check _parallelize_nm5_h for SequenceParallelAllGatherActivation usage
sed -n '423,492p' nemo_rl/models/dtensor/parallelize.py
Length of output: 2203
Add use_local_output=False parameter for consistency with _parallelize_llama.
Lines 253 and 261 in _parallelize_qwen omit the use_local_output=False parameter in SequenceParallelAllGatherActivation(), while _parallelize_llama (lines 201, 203) explicitly specifies it. Although PyTorch's SequenceParallel defaults to use_local_output=False, making this parameter explicit in Qwen maintains consistency and clarity across model configurations. Update lines 253 and 261 to:
"model.layers.*.input_layernorm": SequenceParallelAllGatherActivation(use_local_output=False),
"model.layers.*.post_attention_layernorm": SequenceParallelAllGatherActivation(use_local_output=False),
🤖 Prompt for AI Agents
In nemo_rl/models/dtensor/parallelize.py around lines 253 to 261, the
SequenceParallelAllGatherActivation calls for "model.layers.*.input_layernorm"
and "model.layers.*.post_attention_layernorm" omit the explicit
use_local_output=False parameter; update those two invocations to include
use_local_output=False to match _parallelize_llama and make behavior explicit
and consistent across configurations.
What does this PR do?
Successful run after the fix, with TP=2 and sequence parallel enabled on a Qwen model:
https://wandb.ai/nvidia/grpo-dev-zhiyul/runs/nyq6n98w/overview?nw=nwuserzhiyul
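For readers who want to try this locally, here is a minimal, hypothetical sketch of how a wildcard TP plan like the one in this PR is typically applied. It assumes torch.distributed.tensor.parallel.parallelize_module with wildcard FQNs, a 2-way tensor-parallel mesh launched via torchrun, and a small Qwen checkpoint loaded through transformers; it is not the repo's actual wiring.

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from transformers import AutoModelForCausalLM

from nemo_rl.models.dtensor.parallelize import SequenceParallelAllGatherActivation

# Hypothetical: run under `torchrun --nproc-per-node=2` so a 2-way TP mesh can be built.
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

# Any Qwen-style checkpoint works here; this one is an arbitrary choice for illustration.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

tp_plan = {
    "model.layers.*.input_layernorm": SequenceParallelAllGatherActivation(use_local_output=False),
    "model.layers.*.self_attn.q_proj": ColwiseParallel(),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
}
model = parallelize_module(model, tp_mesh, tp_plan)
```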
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes