
FSDP2 training hangs during backward pass with MoE models when some experts are not activated

Open LucienXian opened this issue 2 months ago • 6 comments

System Info

Environment

  • transformers: 4.53.2
  • torch: 2.7.1+cu128
  • Model: Qwen3-30B-A3B

Who can help?

@seven-mile @ArthurZucker

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [x] My own task or dataset (give details below)

Reproduction

Minimal Reproduction Code:

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling

def train():
    # Initialize distributed training
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    
    # Load MoE model (Qwen3 MoE specifically)
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B",  # Note: This should be a MoE model path
        trust_remote_code=True
    ).cuda()
    
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-30B-A3B",  # Same as model path
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Apply FSDP2 - THIS CAUSES THE HANG WITH MOE MODELS
    use_fsdp2 = True
    if use_fsdp2:
        # FSDP2 is provided by torch.distributed.fsdp (torch >= 2.6)
        from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, 
            reduce_dtype=torch.float32, 
            cast_forward_inputs=True
        )
        for layer in model.model.layers:
            fully_shard(layer, mp_policy=mp_policy)
        fully_shard(model, mp_policy=mp_policy)
    
    # Each DP rank receives different samples, which can lead to some experts
    # not being activated on certain ranks
    train_dataset = getDataset()  # user-defined dataset of tokenized samples
    sampler = DistributedSampler(train_dataset, shuffle=False)
    # WORKAROUND: this configuration prevents the hang (forces identical samples on every DP rank)
    # sampler = DistributedSampler(train_dataset, rank=0, num_replicas=1, shuffle=False)

    # causal-LM collator (mlm=False), matching the DataCollatorForLanguageModeling import above
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1024,
        num_workers=1,
        drop_last=True,
        pin_memory=True,
        collate_fn=data_collator,
        sampler=sampler,
    )
    
    # Training setup
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # Training loop that hangs during backward
    model.train()
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
        
        outputs = model(**batch)
        loss = outputs.loss / 2  # Gradient accumulation steps
        loss.backward()  # HANGS HERE with FSDP2 when some experts aren't activated
        
        if step % 2 == 1:  # Gradient accumulation
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
        
        if rank == 0:
            print(f"Step {step} completed successfully")

if __name__ == "__main__":
    train()

Expected behavior

Training should not get stuck; the backward pass should complete on all ranks even when some experts receive no tokens.

Description

We've encountered a critical issue when training Qwen3 MoE models with FSDP2 in transformers 4.53.2. The training process hangs during the backward pass under specific conditions.

Key observations:

  1. Training hangs during the backward pass when using FSDP2 with MoE models.

  2. The issue does not occur when:

  • We force all DP ranks to receive identical samples (by setting DistributedSampler(rank=0, num_replicas=1))
  • We revert PR #38133

Suspected cause

We suspect PR #38133 introduced a change that causes FSDP2 to hang when some experts in MoE layers are not activated during the forward pass. When certain experts receive no tokens (zero activation) on a rank, FSDP2's gradient synchronization mechanism appears to deadlock during the backward pass.
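
To check whether this condition is actually being hit on a given rank, here is a small diagnostic sketch (a hypothetical helper, not part of transformers; it assumes you can read the router's top-k expert indices, e.g. from the gate output or via a forward hook on the gate module) that reports which experts received zero tokens in a batch:

import torch

def unused_experts(selected_experts: torch.Tensor, num_experts: int) -> list[int]:
    """Return the expert indices that received zero tokens in this batch.

    `selected_experts` has shape (num_tokens, top_k) and holds the top-k
    expert indices chosen by the router for each token.
    """
    counts = torch.bincount(selected_experts.reshape(-1), minlength=num_experts)
    return (counts == 0).nonzero(as_tuple=True)[0].tolist()

# Toy example: 4 experts, 3 tokens, top-2 routing; expert 3 receives no tokens.
sel = torch.tensor([[0, 1], [0, 2], [1, 2]])
print(unused_experts(sel, num_experts=4))  # -> [3]

Printing this per rank (or gathering it across ranks with torch.distributed.all_gather_object) for the batch right before the hang makes it easy to confirm that the stuck step coincides with some expert receiving no tokens on a rank.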

Questions for Maintainers cc @seven-mile @ArthurZucker

  1. Is this a known FSDP2 + MoE limitation?
  2. Should FSDP2 handle unused experts gracefully? (Or is this a bug?) Would a PR to modify expert masking or adjust FSDP2 sync logic help?

Offer to Help

I’m happy to test fixes or collaborate on a PR with guidance! Let me know how I can assist!

LucienXian avatar Oct 27 '25 05:10 LucienXian

Thank you for the bug report, cc @ArthurZucker for MoEs and @sunmarc for Trainer as well!

Rocketknight1 avatar Oct 28 '25 13:10 Rocketknight1

The same problem occurred when training an InternVL3.5 model using FSDP2 (with Qwen3-MoE as the backbone network). Different DP ranks route to different experts, and not all experts are covered, leading to mismatched shapes and communication failures during the gradient reduce. Is there any update on this issue? cc @ArthurZucker @SunMarc

pjgao avatar Nov 05 '25 01:11 pjgao

cc @ArthurZucker

LucienXian avatar Nov 10 '25 17:11 LucienXian

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Dec 05 '25 08:12 github-actions[bot]

cc @3outeille, maybe you faced something similar as you are trying to add support for moe in torchtitan

SunMarc avatar Dec 08 '25 13:12 SunMarc

Interesting, I haven't encountered this issue when training with FSDP and a Qwen3-MoE backbone. Will take a look, but we need a distributed training CI earlier than expected.

3outeille avatar Dec 08 '25 13:12 3outeille

referencing https://github.com/huggingface/transformers/pull/42765

3outeille avatar Dec 10 '25 11:12 3outeille

@LucienXian @ArthurZucker @SunMarc

I encountered the same issue while using LLaMa-Factory. Below are the details regarding the distributed training setup (launched via Accelerate):

Environment & Symptoms

  • FSDP1 / DeepSpeed ZeRO-3: Training works fine.
  • FSDP2 / DeepSpeed ZeRO-2: Training hangs.
  • Diagnosis: Using pystack on the remote PID reveals that the process hangs specifically during clip_grad operations.

Solution

The issue appears to be related to how sparse MoE handles active vs. inactive parameters during gradient synchronization. To resolve this, I manually modified the expert calculation code to force all experts to participate, rather than only calculating the selected top-k experts.

I encountered this behavior across Qwen3-Omni, Qwen3-30B-MoE, and GLM4.6V.

Code Modification (Example: Qwen3-Omni)

I changed the implementation from the original sparse routing:

class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config: Qwen3OmniMoeThinkerConfig):
        super().__init__()
        self.experts = Qwen3OmniMoeThinkerTextExperts(config)
        self.gate = Qwen3OmniMoeThinkerTextTopKRouter(config)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

To a dense implementation in which every expert is computed and non-selected experts are masked with zero weights:

import torch
import torch.nn.functional as F
from torch import nn

# Assumed import path for the Qwen3-Omni modeling module (adjust to your transformers version)
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe


class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextMLP(
                    config, intermediate_size=config.moe_intermediate_size
                )
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Calculate the routing weights for all experts
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        # Keep the top_k weights and set all other experts' weights to 0
        # (instead of running only the top_k experts, as in the original sparse routing)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
        # Initialize an all-zero weight matrix covering all experts
        full_routing_weights = torch.zeros_like(routing_weights)
        # Scatter the top_k weights back; every other expert keeps a weight of 0
        full_routing_weights.scatter_(1, top_k_indices, top_k_weights)

        # Normalize the top_k weights (consistent with the original logic)
        if self.norm_topk_prob:
            # Sum of the top_k weights in each row (for normalization)
            top_k_sum = full_routing_weights.sum(dim=-1, keepdim=True)
            # Avoid division by zero
            top_k_sum = torch.clamp(top_k_sum, min=1e-9)
            full_routing_weights /= top_k_sum

        # Convert back to the input data type
        full_routing_weights = full_routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # Run every expert (not just the selected ones)
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # Per-token weight for the current expert (0 for inactive experts)
            expert_weights = full_routing_weights[:, expert_idx, None]  # shape: (batch*seq, 1)
            # Every token passes through the current expert; the weight may be 0
            current_hidden_states = expert_layer(hidden_states) * expert_weights
            # Accumulate expert outputs (experts with weight 0 do not change the result)
            final_hidden_states += current_hidden_states

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
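
Note that this dense fallback runs every expert on every token, so the expert FLOPs grow roughly by a factor of num_experts / top_k compared to sparse routing; it trades training throughput for making the set of parameters that receive gradients, and therefore FSDP2's gradient collectives, identical on every rank.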

jiaqiw09 avatar Dec 16 '25 07:12 jiaqiw09