The ordering of FSDP, AC, TP, PP, compile, etc.
Based on the code, the ordering of parallelization and optimization appears to be: PP → TP → AC → Compile → FSDP/DDP. Is it possible to modify this ordering? If not, could you explain the rationale for this specific sequence?
PP -> SPMD: making the PP split first allows us to deal with clean SPMD regions. [maybe not fundamental and can be reversed? @wconstab]
TP -> FSDP, because what is applied later wraps outermost, so the runtime semantics during training iterations are the reverse of the apply order. [natural]
TP -> AC (citing @soulitzer)
Switching the ordering of apply_ac/apply_tp in torchtitan seems to result in a "got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators" error somewhere after backward.
[sounds not fundamental]
AC -> torch.compile (citing @soulitzer) Today, wrapping compile over AC is recommended over wrapping AC over compile. Compile already does recompute via the min-cut partitioner, so wrapping AC over a compiled region may lead to recomputing multiple times. If compile wraps AC, compile incorporates the AC region information in the partitioner. It may be possible to improve this behavior, though, e.g. detect any ambient AC contexts and do something different in the partitioner and during runtime. [recommended, sounds fundamental]
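As a concrete sketch of the recommended nesting (compile outside, AC inside), using plain `torch.utils.checkpoint` and a toy module — this is illustrative, not torchtitan's actual code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        # AC inside: activations of `block` are recomputed during backward
        return checkpoint(self.block, x, use_reentrant=False)


model = Model()
# Recommended: compile wraps AC, so the min-cut partitioner can see the
# AC region and avoid redundant recomputation. The reverse (AC wrapping
# a compiled region) may recompute multiple times.
compiled = torch.compile(model)
```

`use_reentrant=False` selects the non-reentrant AC implementation, which is the variant that composes with compile; in eager mode the same model still runs unchanged.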
torch.compile -> FSDP, because torch.compile would graph-break on FSDP2-wrapped modules. [fundamental for PyTorch FSDP2]
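Taken together, the apply order above means each later step wraps the earlier ones, so FSDP ends up outermost at runtime. A minimal pure-Python sketch of that composition within one PP stage (the `apply_*` helpers here are illustrative stand-ins, not torchtitan's actual functions):

```python
# Illustrative stand-ins for the real transforms: each "apply" step
# wraps the module, so later steps end up outermost at runtime.
def wrap(label):
    def apply(module):
        return f"{label}({module})"
    return apply

apply_tp = wrap("TP")
apply_ac = wrap("AC")
apply_compile = wrap("compile")
apply_fsdp = wrap("FSDP")

# Within one PP stage, apply in the order discussed above:
# TP -> AC -> compile -> FSDP.
block = "TransformerBlock"
for step in (apply_tp, apply_ac, apply_compile, apply_fsdp):
    block = step(block)

print(block)  # FSDP(compile(AC(TP(TransformerBlock))))
```

This also makes the "reverse semantics" point concrete: FSDP is applied last but runs first (its all-gather happens before the inner TP compute).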
we probably should document this somewhere @wwwjn
@tianyu-l Would it be okay if I take this up? If so, I can create the documentation here: torchtitan/docs/parallelism_ordering.md
# Parallelism Ordering and Composition Guide
## Overview
This guide explains the correct order for applying different parallelism techniques in TorchTitan.
## Recommended Order
1. Pipeline Parallelism (PP)
2. Tensor Parallelism (TP)
3. Activation Checkpointing (AC)
4. torch.compile
5. FSDP/Data Parallelism
## Why This Order Matters
[Explanation]
Due to the introduction of MoE / EP, support for the orderings has to be expanded, e.g.:
- I believe we now have FSDP(AC) working: https://github.com/pytorch/pytorch/pull/164009
- SAC(compile) is WIP: https://github.com/pytorch/pytorch/issues/161889
As there is active work on this, I'd recommend we hold off a bit and leave it informal for now.
> AC -> torch.compile (citing @soulitzer) Today, compile wrapping AC is more recommended than AC wrapping compile. [...] [recommended, sounds fundamental]
https://github.com/pytorch/pytorch/blob/6fa2cf39d66dff43681fa443a31fd3385ca967d4/torch/utils/checkpoint.py#L348
https://github.com/pytorch/pytorch/blob/6fa2cf39d66dff43681fa443a31fd3385ca967d4/torch/utils/checkpoint.py#L1173
I think compile wrapping AC would disable compile. @tianyu-l
@984881878 Today we are wrapping compile over AC.
@soulitzer could you comment?