FSDP2 & Thunder look more memory-hungry than `thunder.distributed.fsdp` for certain models
🐛 Bug
Let's take stablecode-completion-alpha-3b, whose sequence length (`Config.block_size`) is 16384, and run:
torchrun --standalone --role rank --tee 3 --local-ranks-filter 0 --nproc-per-node 8 thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --distributed_mode fsdp2 --shard_mode zero2 --compile thunder_inductor_cat_cudnn_dynamo
This command goes OOM, while the same configuration (model and compile option) runs on a single H100 with a memory usage of 77.02 GB.
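For context, here is a minimal sketch (not the benchmark script itself) of what the "FSDP2 & Thunder" combination means here, assuming `torch.distributed._composable.fsdp.fully_shard` for FSDP2 and `thunder.dynamo.ThunderCompiler` as the dynamo backend. The helper name `apply_fsdp2_and_thunder` is made up for illustration, and the process group is assumed to be initialized already:

```python
import torch
from torch.distributed._composable.fsdp import fully_shard
from thunder.dynamo import ThunderCompiler


def apply_fsdp2_and_thunder(model: torch.nn.Module) -> torch.nn.Module:
    # Shard each transformer block, then the root module.
    # reshard_after_forward=False roughly corresponds to --shard_mode zero2:
    # parameters stay unsharded after forward, saving a backward all-gather.
    for block in model.transformer.h:
        fully_shard(block, reshard_after_forward=False)
    fully_shard(model, reshard_after_forward=False)

    # Route the graphs captured by torch.compile/dynamo through Thunder.
    return torch.compile(model, backend=ThunderCompiler())
```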
Peak memory for a sequence length of 16384:
| FSDP Impl | Thunder (GB) | Torch Compile (GB) | Diff (GB) |
|---|---|---|---|
| thunder fsdp | 67.74 | N/A | N/A |
| FSDP2 | OOM | 62.7 | N/A |
| FSDP1 | N/A | 62.97 | N/A |
Peak memory for a sequence length of 8192:
| FSDP Impl | Thunder (GB) | Torch Compile (GB) | Diff (GB) |
|---|---|---|---|
| thunder fsdp | 37.58 | N/A | N/A |
| FSDP2 | 56.69 | 35.21 | 21.48 |
| FSDP1 | N/A | 35.24 | N/A |
When `--distributed_mode` is `fsdp`, the benchmark script uses `thunder.distributed.fsdp` when `--compile` is a thunder option without the `dynamo` keyword, and FSDP1 otherwise.
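In other words (a hedged paraphrase of that selection, not the benchmark's actual code; the argument names simply mirror the CLI flags):

```python
def pick_fsdp_impl(distributed_mode: str, compile_option: str) -> str:
    """Paraphrase of how the benchmark maps CLI flags to an FSDP implementation."""
    if distributed_mode == "fsdp":
        if "thunder" in compile_option and "dynamo" not in compile_option:
            return "thunder.distributed.fsdp"  # Thunder's own FSDP transform
        return "FSDP1"  # torch.distributed.fsdp.FullyShardedDataParallel
    if distributed_mode == "fsdp2":
        return "FSDP2"  # per-module fully_shard
    return "none"
```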
Clearly, FSDP2 & Thunder uses too much memory even compared to thunder's fsdp, while thunder's fsdp itself seems to use more memory than Eager and Torch Compile. When I was on #940, I didn't see this trend of memory usage. Also, for Llama-3-8B, thunder still uses more memory, but the gap is not that large.
| FSDP Impl | Thunder (GB) | Torch Compile (GB) | Diff (GB) |
|---|---|---|---|
| thunder fsdp | 75.79 | N/A | N/A |
| FSDP2 | 74.84 | 72.61 | 2.23 |
| FSDP1 | N/A | 73.41 | N/A |
To Reproduce
Apply a diff like the one below and run commands like:
torchrun --standalone --role rank --tee 3 --local-ranks-filter 0 thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --warmup_iters 0 --max_iters 3 --compile eager --dump_memory_snapshot false --block_size 2048
@@ -227,6 +269,7 @@ class Benchmark_litGPT:
fsdp_bucket_params: float | None = None,
checkpoint_activations: bool = False,
n_layers: int | None = None,
+ block_size: int | None = None,
profiler_start: int = 15,
profiler_stop: int = 15,
skip_data_sync: bool = False,
@@ -360,6 +403,8 @@ class Benchmark_litGPT:
if n_layers is not None:
self.config.n_layer = n_layers
+ if block_size is not None:
+ self.config.block_size = block_size
# Initialize the model
t0 = time.perf_counter()
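The tables above report peak GPU memory in GB. If you want an independent check of the figure the benchmark prints, a minimal (hedged) way to measure a comparable number is:

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run warmup plus a few training iterations here ...
peak = torch.cuda.max_memory_allocated() / 2**30  # bytes -> GiB
print(f"peak allocated CUDA memory: {peak:.2f} GB")
```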
Environment
pjnl-20240919
Additional context
related to #1175
cc @carmocca @crcrpar
LayerNorm vs. RMSNorm, GptNeoxMLP vs. LLaMAMLP
FWIW, the MLP used in stablecode is not benchmarked, as per #742.
# stablecode-completion-alpha-3b
GPT(
(lm_head): Linear(in_features=2560, out_features=49152, bias=False)
(transformer): ModuleDict(
(wte): Embedding(49152, 2560)
(h): ModuleList(
(0): Block(
(norm_1): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
(attn): CausalSelfAttention(
(attn): Linear(in_features=2560, out_features=7680, bias=True)
(proj): Linear(in_features=2560, out_features=2560, bias=True)
)
(post_attention_norm): Identity()
(norm_2): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
(mlp): GptNeoxMLP(
(fc): Linear(in_features=2560, out_features=10240, bias=True)
(proj): Linear(in_features=10240, out_features=2560, bias=True)
)
(post_mlp_norm): Identity()
)
)
(ln_f): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
)
)
# Llama-3-8B
GPT(
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
(transformer): ModuleDict(
(wte): Embedding(128256, 4096)
(h): ModuleList(
(0): Block(
(norm_1): RMSNorm()
(attn): CausalSelfAttention(
(attn): Linear(in_features=4096, out_features=6144, bias=False)
(proj): Linear(in_features=4096, out_features=4096, bias=False)
)
(post_attention_norm): Identity()
(norm_2): RMSNorm()
(mlp): LLaMAMLP(
(fc_1): Linear(in_features=4096, out_features=14336, bias=False)
(fc_2): Linear(in_features=4096, out_features=14336, bias=False)
(proj): Linear(in_features=14336, out_features=4096, bias=False)
)
(post_mlp_norm): Identity()
)
)
(ln_f): RMSNorm()
)
)
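For reference, a simplified sketch of the two MLP variants' forward passes (not the exact litgpt code; bias handling and GELU approximation details are omitted):

```python
import torch
import torch.nn.functional as F


class GptNeoxMLP(torch.nn.Module):
    # Used by stablecode-completion-alpha-3b: single up-projection + GELU.
    def __init__(self, n_embd: int, intermediate: int) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(n_embd, intermediate, bias=True)
        self.proj = torch.nn.Linear(intermediate, n_embd, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(F.gelu(self.fc(x)))


class LLaMAMLP(torch.nn.Module):
    # Used by Llama-3-8B: gated SiLU (SwiGLU) with two up-projections.
    def __init__(self, n_embd: int, intermediate: int) -> None:
        super().__init__()
        self.fc_1 = torch.nn.Linear(n_embd, intermediate, bias=False)
        self.fc_2 = torch.nn.Linear(n_embd, intermediate, bias=False)
        self.proj = torch.nn.Linear(intermediate, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(F.silu(self.fc_1(x)) * self.fc_2(x))
```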
This also seems to boil down to `parallel_residual`: if I manually turn `parallel_residual` off, the memory consumption of FSDP2 & ThunderCompiler for stablecode-completion-alpha-3b becomes smaller than that of `thunder.distributed.fsdp`, and it even matches that of `torch.compile` (see the sketch below).
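For clarity, a simplified sketch of the two residual layouts (loosely based on litgpt's `Block.forward`; `attn`, `mlp`, `norm_1`, `norm_2` stand in for the submodules shown above, and the config override in the trailing comment is a hypothetical wiring analogous to the `block_size` override in the diff above):

```python
def block_forward(x, attn, mlp, norm_1, norm_2, parallel_residual: bool):
    if parallel_residual:
        # Both branches read the same input and their outputs are summed, so the
        # attention output must stay live until the MLP branch finishes.
        return x + attn(norm_1(x)) + mlp(norm_2(x))
    # Sequential residual: the MLP branch consumes the attention result.
    x = x + attn(norm_1(x))
    return x + mlp(norm_2(x))


# Hypothetical override, analogous to the block_size override in the diff above:
# if parallel_residual is not None:
#     self.config.parallel_residual = parallel_residual
```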
Hi! Please let me know when we will be ready to check FSDP2 again :)