FSDP2 training hangs during backward pass with MoE models when some experts are not activated
System Info
Environment
- transformers: 4.53.2
- torch: 2.7.1+cu128
- Model: Qwen3-30B-A3B
Who can help?
@seven-mile @ArthurZucker
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [x] My own task or dataset (give details below)
Reproduction
Minimal Reproduction Code:
```python
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling


def train():
    # Initialize distributed training
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Load MoE model (Qwen3 MoE specifically)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B",  # Note: this should be a MoE model path
        trust_remote_code=True,
    ).cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B",  # Same as the model path
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Apply FSDP2 - THIS CAUSES THE HANG WITH MOE MODELS
    use_fsdp2 = True
    if use_fsdp2:
        from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            cast_forward_inputs=True,
        )
        for layer in model.model.layers:
            fully_shard(layer, mp_policy=mp_policy)
        fully_shard(model, mp_policy=mp_policy)

    # Each DP rank gets different samples, which can lead to some experts
    # not being activated on certain ranks
    train_dataset = getDataset()  # placeholder: any causal LM dataset
    sampler = DistributedSampler(train_dataset, shuffle=False)
    # WORKAROUND: this configuration prevents the hang (forces the same samples on every DP rank)
    # sampler = DistributedSampler(train_dataset, rank=0, num_replicas=1, shuffle=False)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1024,
        num_workers=1,
        drop_last=True,
        pin_memory=True,
        collate_fn=data_collator,
        sampler=sampler,
    )

    # Training setup
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Training loop that hangs during backward
    model.train()
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss / 2  # gradient accumulation steps
        loss.backward()  # HANGS HERE with FSDP2 when some experts aren't activated
        if step % 2 == 1:  # gradient accumulation boundary
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
        if rank == 0:
            print(f"Step {step} completed successfully")


if __name__ == "__main__":
    train()
```
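For completeness: the snippet above assumes a standard multi-GPU launch, e.g. `torchrun --nproc_per_node=<num_gpus> repro.py` (`repro.py` is a placeholder filename), since `dist.init_process_group` reads the rank and world size from the launcher-provided environment variables; `getDataset()` stands in for any causal-LM dataset.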
Expected behavior
Training should complete normally and not get stuck during the backward pass.
Description
We've encountered a critical issue when training Qwen3 MoE models with FSDP2 in transformers 4.53.2. The training process hangs during the backward pass under specific conditions.
Key observations:
- Training hangs during the backward pass when using FSDP2 with MoE models.
- The issue does not occur when:
  - we force all DP ranks to receive identical samples (by setting `DistributedSampler(rank=0, num_replicas=1)`), or
  - we revert PR #38133.
Suspected cause
We suspect PR #38133 introduced a change that causes FSDP2 to hang when some experts in MoE layers are not activated during the forward pass. When certain experts receive no tokens (zero activation), FSDP2's gradient synchronization mechanism appears to deadlock during the backward pass.
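For reference, a minimal diagnostic sketch (not part of the original report; `selected_experts`, `num_experts`, and `layer_idx` are placeholder names) that can be called next to each router to confirm that some experts receive zero tokens on a given rank right before the hang:

```python
import torch
import torch.distributed as dist


def log_unused_experts(selected_experts: torch.Tensor, num_experts: int, layer_idx: int) -> None:
    # selected_experts: (num_tokens, top_k) expert indices produced by the router
    counts = torch.bincount(selected_experts.flatten(), minlength=num_experts)
    unused = (counts == 0).nonzero(as_tuple=True)[0].tolist()
    if unused:
        rank = dist.get_rank() if dist.is_initialized() else 0
        print(f"[rank {rank}] layer {layer_idx}: experts with zero tokens this step: {unused}", flush=True)
```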
Questions for Maintainers cc @seven-mile @ArthurZucker
- Is this a known FSDP2 + MoE limitation?
- Should FSDP2 handle unused experts gracefully? (Or is this a bug?) Would a PR to modify expert masking or adjust FSDP2 sync logic help?
Offer to Help
I’m happy to test fixes or collaborate on a PR with guidance! Let me know how I can assist!
Thank you for the bug report, cc @ArthurZucker for MoEs and @sunmarc for Trainer as well!
The same problem occurred when training an InternVL3.5 model using FSDP2 (with Qwen3MoE as the backbone network). Different DP ranks routed tokens to different experts, and not all experts were covered, leading to mismatched shapes and communication failures during the gradient reduce. Is there any update on this issue? cc @ArthurZucker @SunMarc
cc @ArthurZucker
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
cc @3outeille, maybe you faced something similar as you are trying to add MoE support in torchtitan
Interesting, I haven't encountered this issue when training with FSDP and a Qwen3-MoE backbone. Will take a look, but we need a distributed training CI earlier than expected.
referencing https://github.com/huggingface/transformers/pull/42765
@LucienXian @ArthurZucker @SunMarc
I encountered the same issue while using LLaMa-Factory. Below are the details regarding the distributed training setup (launched via Accelerate):
Environment & Symptoms
- FSDP1 / DeepSpeed ZeRO-3: Training works fine.
- FSDP2 / DeepSpeed ZeRO-2: Training hangs.
- Diagnosis: Using `pystack` on the remote PID reveals that the process hangs specifically during `clip_grad` operations (a minimal logging sketch follows this list).
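A sketch of the kind of logging that can localize this (illustrative only; `clip_with_logging` is a hypothetical helper, not part of LLaMA-Factory):

```python
import torch
import torch.distributed as dist


def clip_with_logging(parameters, max_norm: float) -> torch.Tensor:
    # Print before and after the clipping call so a stuck rank is easy to spot.
    rank = dist.get_rank()
    print(f"[rank {rank}] entering clip_grad_norm_", flush=True)
    total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
    print(f"[rank {rank}] left clip_grad_norm_, total_norm={total_norm}", flush=True)
    return total_norm
```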
Solution
The issue appears to be related to how sparse MoE handles active vs. inactive parameters during gradient synchronization. To resolve this, I manually modified the expert calculation code to force all experts to participate, rather than only calculating the selected top-k experts.
I encountered this behavior across Qwen3-Omni, Qwen3-30B-MoE, and GLM4.6V.
Code Modification (Example: Qwen3-Omni)
I changed the implementation from the original sparse routing:
```python
class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config: Qwen3OmniMoeThinkerConfig):
        super().__init__()
        self.experts = Qwen3OmniMoeThinkerTextExperts(config)
        self.gate = Qwen3OmniMoeThinkerTextTopKRouter(config)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
```
To a dense execution where all experts are calculated, and non-selected experts are masked with zero weights:
```python
class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
                    config, intermediate_size=config.moe_intermediate_size
                )
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        # Compute routing weights over all experts
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # Keep the top_k weights and zero out the rest (instead of keeping only the top_k experts)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Initialize an all-zero weight matrix (one column per expert)
        full_routing_weights = torch.zeros_like(routing_weights)
        # Only the top_k experts keep their weights; all other experts stay at 0
        full_routing_weights.scatter_(1, top_k_indices, top_k_weights)
        # Normalize the top_k weights (consistent with the original logic)
        if self.norm_topk_prob:
            # Sum of the top_k weights per row (for normalization)
            top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
            # Avoid division by zero
            top_k_sum = torch.clamp(top_k_sum, min=1e-9)
            full_routing_weights /= top_k_sum
        # Convert back to the input dtype
        full_routing_weights = full_routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # Loop over ALL experts (not just the selected ones) so every rank issues the same work
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # Weight of the current expert for each token (0 for inactive experts)
            expert_weights = full_routing_weights[:, expert_idx, None]  # shape: (batch*seq, 1)
            # Every token passes through the current expert; the weight may simply be 0
            current_hidden_states = expert_layer(hidden_states) * expert_weights
            # Accumulate all expert outputs (experts with weight 0 do not affect the result)
            final_hidden_states += current_hidden_states
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
```