
The ordering of FSDP, AC, TP, PP, compile, etc.

aoyulong opened this issue 5 months ago · 4 comments

Based on the code, the ordering of parallelization and optimization appears to be: PP → TP → AC → Compile → FSDP/DDP. Is it possible to modify this ordering? If not, could you explain the rationale for this specific sequence?

aoyulong · Aug 12 '25

PP -> SPMD: making the PP split first allows us to deal with clean SPMD regions. [maybe not fundamental and could be reversed? @wconstab]

TP -> FSDP: because at runtime they act in the reverse of the order in which they are applied, so FSDP, applied last, acts outermost during training iterations. [natural]
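
A minimal sketch of that composition (not the actual torchtitan code; `model` and `tp_plan` are hypothetical, and an 8-GPU 2x4 mesh is assumed): TP shards parameters along the inner mesh dimension first, then FSDP2's `fully_shard` further shards the resulting DTensor parameters along the outer one.

```python
# Sketch only: `model` is a hypothetical nn.Module, `tp_plan` a hypothetical
# parallelize_module plan; assumes 8 GPUs arranged as a 2x4 device mesh.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2
from torch.distributed.tensor.parallel import parallelize_module

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
parallelize_module(model, mesh_2d["tp"], tp_plan)  # applied first, acts innermost
fully_shard(model, mesh=mesh_2d["dp"])             # applied last, acts outermost
```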

TP -> AC (citing @soulitzer): switching the order of apply_ac/apply_tp in torchtitan seems to result in a "got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators" error somewhere after backward. [sounds not fundamental]

AC -> torch.compile (citing @soulitzer): today, wrapping compile around AC is recommended over wrapping AC around compile. Compile already does recompute via the min-cut partitioner, so wrapping AC over a compiled region may lead to recomputing multiple times. If compile wraps AC, compile incorporates the AC region information in the partitioner. It may be possible to improve this behavior, e.g. detect any ambient AC contexts and do something different in the partitioner and at runtime. [recommended, sounds fundamental]
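
A toy contrast of the two wrapping orders (hypothetical `block`, using plain `torch.utils.checkpoint` rather than torchtitan's AC wrapper):

```python
# Toy illustration only; `block` is a hypothetical nn.Module.
import torch
from torch.utils.checkpoint import checkpoint

def compile_wraps_ac(block):
    # Recommended: compile traces through the checkpointed region, so the
    # min-cut partitioner plans recompute with the AC info in view.
    return torch.compile(lambda x: checkpoint(block, x, use_reentrant=False))

def ac_wraps_compile(block):
    # Discouraged: checkpoint recomputes the entire compiled region in
    # backward, on top of whatever the partitioner already recomputes.
    compiled = torch.compile(block)
    return lambda x: checkpoint(compiled, x, use_reentrant=False)
```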

torch.compile -> FSDP: because torch.compile would graph break on FSDP2-wrapped modules. [fundamental for PyTorch FSDP2]
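
Putting the whole ordering together, a hedged sketch of how a single PP stage might be parallelized (paraphrasing torchtitan's apply_* helpers, not the exact APIs; `model.layers`, `tp_mesh`, `dp_mesh`, and `tp_plan` are assumed):

```python
# Sketch of the full ordering for one PP stage; names are illustrative.
# Assumes `model.layers` holds the transformer blocks and that
# `tp_mesh` / `dp_mesh` are pre-built DeviceMeshes.
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import fully_shard  # FSDP2
from torch.distributed.tensor.parallel import parallelize_module

def parallelize_stage(model, tp_mesh, dp_mesh, tp_plan):
    # (PP: the model was already split into stages before this point.)
    # TP: shard parameters as DTensors according to the plan.
    parallelize_module(model, tp_mesh, tp_plan)
    # AC: wrap each block so its activations are recomputed in backward.
    for name, block in list(model.layers.named_children()):
        model.layers.register_module(name, checkpoint_wrapper(block))
    # Compile: per block, so compile wraps AC rather than the reverse.
    for name, block in list(model.layers.named_children()):
        model.layers.register_module(name, torch.compile(block))
    # FSDP2 last: compile never sees FSDP hooks, avoiding graph breaks.
    for block in model.layers.children():
        fully_shard(block, mesh=dp_mesh)
    fully_shard(model, mesh=dp_mesh)
    return model
```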

tianyu-l · Aug 13 '25

We should probably document this somewhere. @wwwjn

tianyu-l · Aug 13 '25

@tianyu-l Would it be okay if I take this up? If so, I can create the documentation here: torchtitan/docs/parallelism_ordering.md

# Parallelism Ordering and Composition Guide

## Overview
This guide explains the correct order for applying different parallelism techniques in TorchTitan.

## Recommended Order

1. Pipeline Parallelism (PP)
2. Tensor Parallelism (TP)
3. Activation Checkpointing (AC)
4. torch.compile
5. FSDP/Data Parallelism

## Why This Order Matters
[Explanation]

krishnakalyan3 · Oct 07 '25

With the introduction of MoE / EP, support for additional orderings has to be added, e.g.:

  • I believe FSDP(AC) now works: https://github.com/pytorch/pytorch/pull/164009
  • SAC(compile) is WIP: https://github.com/pytorch/pytorch/issues/161889

As there is active work on this, I'd recommend we hold off a bit and leave it informal for now.

tianyu-l · Oct 08 '25

> AC -> torch.compile (citing @soulitzer): today, wrapping compile around AC is recommended over wrapping AC around compile. Compile already does recompute via the min-cut partitioner, so wrapping AC over a compiled region may lead to recomputing multiple times. If compile wraps AC, compile incorporates the AC region information in the partitioner. It may be possible to improve this behavior, e.g. detect any ambient AC contexts and do something different in the partitioner and at runtime. [recommended, sounds fundamental]

Looking at https://github.com/pytorch/pytorch/blob/6fa2cf39d66dff43681fa443a31fd3385ca967d4/torch/utils/checkpoint.py#L348 and https://github.com/pytorch/pytorch/blob/6fa2cf39d66dff43681fa443a31fd3385ca967d4/torch/utils/checkpoint.py#L1173, I think compile wrapping AC would disable compile. @tianyu-l

984881878 · Dec 12 '25

@984881878 Today we are wrapping compile over AC.

@soulitzer could you comment?

tianyu-l · Dec 12 '25