Fix `sort_waits` to move `wait` closer to its consumer (#277)
Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
What does this PR do?
Fixes #277.
For more details, please see: https://github.com/Lightning-AI/lightning-thunder/issues/277#issuecomment-2100595522
cc @carmocca @awaelchli @crcrpar
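To make the intent of the title concrete, here is a minimal, self-contained sketch of sinking each `wait` down to just before its first consumer. This is not the actual thunder `sort_waits` implementation; `Op`, `sink_waits`, and the toy op list below are illustrative assumptions only.

```python
from dataclasses import dataclass

@dataclass
class Op:
    name: str                 # e.g. "all_gather", "wait", "linear"
    outputs: tuple = ()       # names of values this op produces
    inputs: tuple = ()        # names of values this op consumes

def is_wait(op: Op) -> bool:
    return op.name == "wait"

def sink_waits(ops: list[Op]) -> list[Op]:
    """Move each wait down so it sits right before its first consumer."""
    result: list[Op] = []
    pending: list[Op] = []    # waits whose consumers we have not reached yet
    for op in ops:
        if is_wait(op):
            pending.append(op)
            continue
        # emit any pending wait whose output this op consumes
        for w in [w for w in pending if set(w.outputs) & set(op.inputs)]:
            pending.remove(w)
            result.append(w)
        result.append(op)
    result.extend(pending)    # waits that were never consumed go last
    return result

ops = [
    Op("all_gather", outputs=("ag0",), inputs=("w0",)),
    Op("wait", outputs=("p0",), inputs=("ag0",)),
    Op("all_gather", outputs=("ag1",), inputs=("w1",)),
    Op("wait", outputs=("p1",), inputs=("ag1",)),
    Op("linear", outputs=("y0",), inputs=("x", "p0")),
    Op("linear", outputs=("y1",), inputs=("y0", "p1")),
]
print([op.name for op in sink_waits(ops)])
# ['all_gather', 'all_gather', 'wait', 'linear', 'wait', 'linear']
```

With the waits sunk, the second all-gather can be issued before the first wait blocks, which is the communication/computation overlap this PR is after.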
Unfortunately, this change doesn't improve performance on Llama-2-7b-hf.
Benchmark command: `torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=2 --global_batch_size=16 --model_name=Llama-2-7b-hf --return_metrics_as_json=True --json_path=benchmark_litgpt_datanew.json`
Env: 8× H100 80GB, nvFuser 0.2.3+git729f36c
Before this commit (7cff363abb814dafc4d1752253f30763cdb0a6f3):
"average_iter_time": 787.7525961026549,
"model_flops": 377527625318400,
"model_flop_per_sec": 3834022026881291.5,
"tokens_per_sec": 83194.7289094999,
"memory_used_GB": 65.789196288,
"model_name": "Llama-2-7b-hf",
"Num GPUS": 8,
"Seq Len": 4096,
"Micro BS": 2,
"Global BS": 16,
"GA": 1,
"Distributed Mode": "fsdp_zero2_none_bucketing",
"Sharding Size": null,
"compiler": "thunder"
This PR:
"average_iter_time": 793.2260634377599,
"model_flops": 377527625318400,
"model_flop_per_sec": 3809612930806172.5,
"tokens_per_sec": 82665.07411965841,
"memory_used_GB": 65.790769152,
"model_name": "Llama-2-7b-hf",
"Num GPUS": 8,
"Seq Len": 4096,
"Micro BS": 2,
"Global BS": 16,
"GA": 1,
"Distributed Mode": "fsdp_zero2_none_bucketing",
"Sharding Size": null,
"compiler": "thunder"
cc: @IvanYashchuk
> Unfortunately, this change doesn't improve performance on Llama-2-7b-hf.
Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling tell us about the overlap?
> Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling tell us about the overlap?
Before this commit:
This PR:
Although the performance doesn't change, after this PR the allgathers in stream 22 overlap with the computation in stream 7.
@IvanYashchuk
Hi @t-vi @carmocca, I think it's ready to merge.
Just tried this branch, and while I do see some small overlap with the first few layers, the majority is still not overlapped and the AllGathers are launched much ahead of the compute.
I think it's expected that all the allgathers are launched at the beginning. To reduce the overhead, the rate limiting we do for zero3 may be needed for zero2 as well, which https://github.com/Lightning-AI/lightning-thunder/pull/259 addresses by adding an argument.
> Just tried this branch, and while I do see some small overlap with the first few layers, the majority is still not overlapped and the AllGathers are launched much ahead of the compute.
Are stream barriers visible in the profile? They would be a good indicator of whether we're doing the correct thing on the trace level.
> I think it's expected that all the allgathers are launched at the beginning. To reduce the overhead, the rate limiting we do for zero3 may be needed for zero2 as well, which #259 addresses by adding an argument.
For zero3, rate limiting is needed to bound the peak allocated memory, because every allgather call allocates its unsharded output tensor; we need to limit the number of active allgathers until each result is consumed and freed.
For zero2 we don't care about the memory consumption, because all the unsharded tensors are saved for backward. Are there other effects of limiting the number of allgathers besides peak memory allocation?
In zero2 the forward trace would consist of a sequence of all-gathers followed by a sequence of computations, which would explain the long idling in the compute stream (at least on paper). To shorten the first sequence of all-gathers, rate limiting should help, as it reorders the bsyms so that a few all-gathers are followed by some compute that uses those all-gathers' outputs.
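As a toy illustration of this rate-limiting behaviour (not the #259 implementation, which operates on trace bound symbols; `schedule_with_limit` and the `(allgather, wait, compute)` triples are made-up names), a schedule that caps the number of in-flight all-gathers could look like:

```python
# A toy model of rate-limited prefetching: `layers` is a list of
# (allgather, wait, compute) op names per transformer block, already in
# consumption order; at most `limit` all-gathers are in flight at once.
def schedule_with_limit(layers, limit):
    schedule = [ag for ag, _, _ in layers[:limit]]   # prefetch the first `limit`
    next_to_issue = min(limit, len(layers))
    for ag, wait, compute in layers:
        schedule.append(wait)                        # wait right before its consumer
        if next_to_issue < len(layers):              # the freed slot prefetches one more
            schedule.append(layers[next_to_issue][0])
            next_to_issue += 1
        schedule.append(compute)
    return schedule

layers = [(f"allgather_{i}", f"wait_{i}", f"block_{i}") for i in range(4)]
print(schedule_with_limit(layers, limit=2))
# ['allgather_0', 'allgather_1', 'wait_0', 'allgather_2', 'block_0',
#  'wait_1', 'allgather_3', 'block_1', 'wait_2', 'block_2', 'wait_3', 'block_3']
```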
For #277, after bisecting I found https://github.com/Lightning-AI/lightning-thunder/commit/a76beb6328149b7799765a58ed38a892be39ca97 to be the first bad commit. Comparing the traces before and after that commit, the order of the allgathers changes: before the commit, the first allgather to be consumed is always the first to appear in the trace, and so on. So I experimented with an early transform that reorders the parameters specifically for Llama-2-7b-hf, as @t-vi suggested, and the allgathers then overlap better with the computation:
The next step will be to rearrange all the allgathers into the order in which they are consumed.
Patch specific to Llama-2-7b-hf:
diff --git a/thunder/__init__.py b/thunder/__init__.py
index adc6d692..a82d0d1d 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -301,7 +301,8 @@ def jit(
     interpreter = _general_frontend
     if early_transforms is None:
-        early_transforms = []
+        # early_transforms = []
+        early_transforms = [tmp_trans, ]
     if additional_transforms is None:
         additional_transforms = []
@@ -938,3 +939,36 @@ def grad(fn):
return original_result, original_trace
return _fn
+def tmp_trans(prologue_trc, computation_trc, epilogue_trc, executors_list):
+    old_order = ("tos1", "sin", "transformer_wte_weight") + tuple(
+        f"transformer_h_{i}_{suffix}"
+        for i in range(32)
+        for suffix in (
+            "norm_1_weight", "attn_attn_weight", "attn_proj_weight", "norm_2_weight",
+            "mlp_fc_1_weight", "mlp_fc_2_weight", "mlp_proj_weight",
+        )
+    ) + ("transformer_ln_f_weight", "lm_head_weight", "idx")
+    ret = prologue_trc.bound_symbols[-1]
+    assert ret.sym.id == prims.PrimIDs.RETURN
+
+    def sort_func(x):
+        if x.name in old_order:
+            return old_order.index(x.name)
+        assert x.name[2:] in old_order
+        return old_order.index(x.name[2:])
+    new_ret_args = tuple(sorted(ret.args[0], key=sort_func))
+    from dataclasses import replace
+    new_ret = replace(ret, args=(new_ret_args,))
+    new_ret = replace(new_ret, output=(new_ret_args,))
+    prologue_trc.bound_symbols[-1] = new_ret
+
+    new_args = tuple(sorted(computation_trc.args, key=sort_func))
+    computation_trc.args = new_args
+    new_list = []
+    old_list = computation_trc.bound_symbols[0:len(old_order)]
+
+    for n in old_order:
+        tmp = [p for p in old_list if p.output.name == n or p.output.name == "t_" + n]
+        assert len(tmp) == 1
+        new_list.append(tmp[0])
+    assert len(new_list) == len(old_order)
+    computation_trc.bound_symbols[0:len(old_order)] = new_list
+
+    siginfo_args = []
+    for bsym in new_list:
+        siginfo_args.append((bsym.output.name, None))
+    computation_trc._siginfo.args = siginfo_args
+    return prologue_trc, computation_trc, epilogue_trc
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 4b553c2f..816c6f11 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -403,6 +403,11 @@ class Benchmark_litGPT:
input_ids = input_ids.to(device=self.device)
targets = targets.to(device=self.device)
loss = run_fwd_bwd_one_microbatch(self.model, input_ids, targets, self.gradient_accumulation_steps)
+ # if i==0 and global_rank==0:
+ # # with open('old_trace','w') as f:
+ # with open('reordertrace','w') as f:
+ # f.write(str(thunder.last_traces(self.model)[-1]))
+ # f.write(str(thunder.last_backward_traces(self.model)[-1]))
# Simple Gradient Accumulation Implementation
self.optimizer.step()
Hi @IvanYashchuk @crcrpar, using sort_wait_zero3 (sorting the allgather + wait to just before the consumer) together with an unlimited number of in-flight allgathers (pushing the allgathers to the beginning of the trace) fixes the problem.
It sorts the allgathers into their consumer order and places them at the beginning of the trace, with the corresponding waits right before their consumers.
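A toy sketch of that ordering (illustrative names only, not the actual transform): the all-gathers sit at the top of the trace in consumption order, and every wait is placed immediately before its consumer.

```python
# Toy sketch: hoist all-gathers to the top of the trace in consumption
# order; each wait stays immediately before the compute that consumes it.
def reorder_trace(layers):
    allgathers = [ag for ag, _, _ in layers]                      # comms launched up front
    body = [op for _, wait, compute in layers for op in (wait, compute)]
    return allgathers + body

layers = [(f"allgather_{i}", f"wait_{i}", f"block_{i}") for i in range(3)]
print(reorder_trace(layers))
# ['allgather_0', 'allgather_1', 'allgather_2',
#  'wait_0', 'block_0', 'wait_1', 'block_1', 'wait_2', 'block_2']
```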
nsys results for Llama-2-7b-hf:
Nice, so the main problem was that the order of the allgathers did not follow the order in which they are used in the compute part, forcing a long delay before the compute started.
Before merging, let's see what impact this latest iteration has on perf.
Benchmark command: `torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=2 --global_batch_size=16 --model_name=Llama-2-7b-hf`
On main (b8705922c344a7d08f9ac43ac1b06d2ff7bbaf62):
Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 2
Global BS: 16
Number of Layers: 32
Number of parameters: 0.84B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Average iter time: 780.92 ms
Memory used: 65.79 GB
Tokens/s: 83912.17
Tokens/s/GPU: 10489.02
TFLOP/s: 3867.09
This PR:
Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 2
Global BS: 16
Number of Layers: 32
Number of parameters: 0.84B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Average iter time: 776.61 ms
Memory used: 65.79 GB
Tokens/s: 84361.17
Tokens/s/GPU: 10545.15
TFLOP/s: 3887.78
I think it's ready to merge, @IvanYashchuk @t-vi.