
Context Parallel for Qwen3

Open unavailableun opened this issue 1 month ago • 4 comments

Thanks for supporting Qwen3 models!

CP is not supported currently because of RoPE embedding implementation details.

Any plan to support CP + EP for Qwen3 MoE models? If there's no plan in the short term, can you guide me on how to implement it myself?

unavailableun avatar Nov 24 '25 08:11 unavailableun

cc @fegin

tianyu-l avatar Nov 24 '25 18:11 tianyu-l

I saw an alert only for CP & FlexAttention; does that mean CP works with SDPA? [Image]

unavailableun avatar Nov 25 '25 03:11 unavailableun

cc @wwwjn @shuhuayu

tianyu-l avatar Nov 25 '25 04:11 tianyu-l

Thanks for asking! CP is not officially supported for Qwen3 yet, and we haven't implemented or tested CP on Qwen3. This is because of RoPE embedding differences: in Qwen3 the RoPE field is named rope_cache, while train.py looks for a freqs_cis field.

Qwen3 README: https://github.com/pytorch/torchtitan/tree/refs/heads/main/torchtitan/models/qwen3
Train.py: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L480

So it won't work with SDPA either, but it should be a simple change.

wwwjn avatar Nov 25 '25 14:11 wwwjn

What I did to get Qwen3 working with CP:

  1. Added an alias property to the Qwen3 model class mapping freqs_cis -> rope_cache:

     @property
     def freqs_cis(self) -> torch.Tensor:
         """Alias for rope_cache to maintain compatibility with torchtitan context parallelism."""
         return self.rope_cache

  2. Ran into some bugs in the cuDNN backend for SDPA, so excluded it from models.attention.ScaledDotProductAttentionWrapper (see the sketch after this list).
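
For item 2, a minimal sketch of one way to exclude the cuDNN backend via PyTorch's public backend-selection API (the function name and signature here are illustrative assumptions, not torchtitan's ScaledDotProductAttentionWrapper):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def sdpa_without_cudnn(q, k, v, is_causal=True):
    # Restrict SDPA to the flash / memory-efficient / math kernels; cuDNN is left out.
    allowed = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
    with sdpa_kernel(allowed):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)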

I also ran into a deeper memory-leak bug related to running EP and FSDP together (for the MoE variants), where memory would accumulate during the backward pass until OOM.

This was solved by updating the explicit prefetch lists set in llama4/infra/parallelize.py#L425:

# layer_N should prefetch its OWN experts (the next thing to run)
transformer_block.set_modules_to_backward_prefetch([transformer_block.moe.experts])
# layer_N.experts should prefetch layer_{N-1} (the next layer)
if prev_transformer_block is not None:
    transformer_block.moe.experts.set_modules_to_backward_prefetch(
        [prev_transformer_block]
    )
elif model.tok_embeddings is not None:
    transformer_block.moe.experts.set_modules_to_backward_prefetch(
        [model.tok_embeddings]
    )
else:
    transformer_block.moe.experts.set_modules_to_backward_prefetch([])
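
For context, the snippet above runs inside a loop over the transformer blocks; a sketch of the assumed surrounding structure (names follow the snippet, not a verbatim copy of parallelize.py):

prev_transformer_block = None
for transformer_block in model.layers.values():
    # ... apply the prefetch settings from the snippet above to this block ...
    prev_transformer_block = transformer_block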

The above issue is unrelated to CP but was a blocker. I hope to provide a write-up of my understanding of the issue and this solution (and a repro example) but haven't had the time yet.

pstefa1707 avatar Dec 11 '25 00:12 pstefa1707

@shuhuayu could you please take a look at the memory leak issue mentioned above?

tianyu-l avatar Dec 11 '25 10:12 tianyu-l

I tried to create a minimal reproducible example and spent an hour with the debug-sized model, but couldn't reproduce it (I was only hitting this when training Qwen3-235B on 2 nodes). I also had some local changes to Qwen3 for adding LoRAs, although I don't think they'd impact this.

I appreciate this is not a useful bug report. If this is something anyone else has run into, I am keen to spend more time packaging this bug up into something that can be iterated on.

Here are the before- and after-fix memory snapshots. Note the memory accumulating during the backward pass.

[Image: memory snapshot before fix] [Image: memory snapshot after fix]

Here are my personal debug notes, forgive the vibey-ness. I am happy to hop on a call and discuss if needed. My thesis may very well be wrong.

Summary

A memory leak during the backward pass of training Qwen3-235B-A22B, where ~0.5GB of GPU memory accumulated per layer processed. With 94 layers, this led to OOM well before backward pass completion. The root cause was FSDP2's implicit backward prefetch mechanism fetching weights for layers that had already completed their backward pass, resulting in unsharded tensors that were never freed.


The Architecture

FSDP2 Module Hierarchy

The model has a nested FSDP structure:

Model (FSDPModule)
├── tok_embeddings (FSDPModule)
├── layers.0 (FSDPModule - TransformerBlock)
│   └── moe.experts (FSDPModule - separate mesh!)  ← Key: This is ALSO wrapped
├── layers.1 (FSDPModule - TransformerBlock)
│   └── moe.experts (FSDPModule)
├── ...
├── layers.93 (FSDPModule - TransformerBlock)
│   └── moe.experts (FSDPModule)
└── norm, output (FSDPModule)

Each TransformerBlock AND its inner moe.experts are independently FSDP-wrapped. The experts are on a different mesh (dp_mod_ep_mesh) due to Expert Parallelism.
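
An illustrative sketch of this nested wrapping, assuming FSDP2's fully_shard API and placeholder mesh variables dp_mesh / dp_mod_ep_mesh (not torchtitan's exact code; the import path may vary by PyTorch version):

from torch.distributed.fsdp import fully_shard

for transformer_block in model.layers.values():
    # Inner FSDPModule: experts get their own param group on the EP-aware mesh.
    fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)
    # Outer FSDPModule: the rest of the block is sharded on the regular dp mesh.
    fully_shard(transformer_block, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)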

Two Levels of Prefetching in FSDP2

FSDP2 has two distinct prefetch mechanisms:

  1. Module-level prefetch (explicit API):
    • Controlled via module.set_modules_to_backward_prefetch([...])
    • Stored in _modules_to_backward_prefetch attribute
    • This is what torchtitan's apply_fsdp configures
  2. ParamGroup-level prefetch (implicit/internal):
    • Controlled internally by FSDPParamGroup._backward_prefetch()
    • Called from FSDPParamGroup.pre_backward() hook
    • Not exposed via any public API
    • Uses internal logic to decide what to prefetch

The Bug: Detailed Mechanism

Normal FSDP2 Backward Flow

For a single layer's backward pass, the correct flow is:

1. pre_backward hook fires
2. UNSHARD layer N's weights (all-gather)
3. Compute gradients for layer N
4. post_backward hook fires
5. RESHARD layer N's weights (reduce-scatter gradients, free unsharded weights)
6. Move to layer N-1

What Was Actually Happening

With activation checkpointing (mode = "full") and nested FSDP for MoE experts:

Processing Layer 34's backward:
─────────────────────────────────────────────────────────
1. UNSHARD layers.34 (TransformerBlock wrapper)              [+139MB]
2. UNSHARD layers.34.moe.experts (param group 1)             [+562MB]
3. UNSHARD layers.34.moe.experts (param group 2)             [+562MB]
   ↑ These are for AC forward recomputation
4. RESHARD layers.34.moe.experts (AC forward done)           [-1.1GB]
5. UNSHARD layers.34.moe.experts (for actual backward)       [+562MB]
6. ⚠️ UNSHARD layers.35.moe.experts ← IMPLICIT PREFETCH      [+562MB]
7. POST_BACKWARD layers.34.moe.experts
8. RESHARD layers.34.moe.experts                             [-562MB]
9. POST_BACKWARD layers.34
10. RESHARD layers.34                                        [-139MB]

After layer 34 complete: layers.35.moe.experts STILL UNSHARDED!

Why Layer 35's Experts Were Never Freed

The critical issue at step 6:

  • FSDPParamGroup._backward_prefetch() for layer 34's experts was triggered
  • This internal method decided to prefetch layer 35's experts
  • But layer 35's backward pass had already completed (we're going 93→0)
  • Since layer 35's post_backward hook already ran, there was no subsequent hook to RESHARD those prefetched weights
  • The ~562MB stayed allocated indefinitely

The Accumulation Pattern

Layer 93 backward complete: baseline memory
Layer 92 backward complete: +0.5GB (layer 93 experts prefetched, never freed)
Layer 91 backward complete: +0.5GB (layer 92 experts prefetched, never freed)
...
Layer 34 backward complete: +30GB cumulative
Layer 33 backward complete: +30.5GB cumulative
...
Layer ~27: OOM at ~71GB
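
(Roughly 59 layers from layer 92 down through layer 34, at ~0.5GB each, accounts for the ~30GB cumulative figure above.)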


Why Explicit Prefetch Settings Didn't Work

What torchtitan's set_modules_to_backward_prefetch([]) Does

This sets the module-level prefetch list to empty:

transformer_block._modules_to_backward_prefetch = []
transformer_block.moe.experts._modules_to_backward_prefetch = []

What It Doesn't Do

It does NOT affect FSDPParamGroup._backward_prefetch(), which is an internal method that:

  1. Is called from the pre_backward hook on the param group
  2. Has its own logic for determining what to prefetch
  3. Appears to prefetch based on module registration order (forward order: 0→93)
  4. Ignores the module-level _modules_to_backward_prefetch setting

The Two Prefetch Systems Don't Talk to Each Other

Module Level (public API):
    FSDPModule.set_modules_to_backward_prefetch([])
    → Sets FSDPModule._modules_to_backward_prefetch = []
    → Affects nothing in our case (was already None)

ParamGroup Level (internal):
    FSDPParamGroup.pre_backward()
    → Calls FSDPParamGroup._backward_prefetch()
    → Uses internal/implicit logic
    → Ignores module-level settings
    → THIS IS WHAT WAS CAUSING THE LEAK


Why the Prefetch Was Going in the Wrong Direction

The Root Cause: Module Registration Order

When FSDP2 wraps modules, it records them in forward order:

  • layers.0, layers.1, ..., layers.93

The implicit prefetch in FSDPParamGroup._backward_prefetch() appears to use this registration order to determine prefetch targets, possibly via something like:

# Pseudo-code of what FSDP2 might be doing internally
def _backward_prefetch(self):
    next_module = self._get_next_registered_module()  # Based on forward order!
    if next_module:
        next_module.unshard()  # Prefetch for "next" module

But during backward, we process in reverse order (93→0). So when processing layer 34:

  • FSDP2 thinks "next" is layer 35 (based on forward registration order)
  • But in backward, "next" should be layer 33
  • Layer 35 was already processed, so its post_backward won't run again
  • The prefetched weights are orphaned

The Fix: Disabling Implicit Prefetch

What the Monkey-Patch Does

# Import path for FSDP2 internals; may differ across PyTorch versions
# (older releases keep this under torch.distributed._composable.fsdp).
import torch.distributed.fsdp._fully_shard._fsdp_param_group as fsdp_pg

def _disabled_backward_prefetch(self):
    pass  # Do nothing

# Apply at startup, before the first backward pass runs.
fsdp_pg.FSDPParamGroup._backward_prefetch = _disabled_backward_prefetch

This completely bypasses the implicit prefetch logic at the FSDPParamGroup level. Now the backward pass flow becomes:

Processing Layer 34's backward:
─────────────────────────────────────────────────────────
1. UNSHARD layers.34
2. UNSHARD layers.34.moe.experts (for AC recompute)
3. RESHARD layers.34.moe.experts (AC done)
4. UNSHARD layers.34.moe.experts (for backward)
5. POST_BACKWARD layers.34.moe.experts
6. RESHARD layers.34.moe.experts                    ← Properly freed!
7. POST_BACKWARD layers.34
8. RESHARD layers.34                                ← Properly freed!

After layer 34 complete: NO orphaned tensors!

Attached is the memory snapshot for the OOM; it includes memory-allocation stack traces, which may be useful for providing more context: 235b_oom_always_reshard_full_ac.pickle.zip

pstefa1707 avatar Dec 12 '25 18:12 pstefa1707

I tentatively enabled CP + SDPA for Qwen3 in https://github.com/pytorch/torchtitan/pull/2144. But I haven't verified the EP + CP part, which may need some verification.

fegin avatar Dec 15 '25 23:12 fegin