lightning-thunder

Fix `sort_waits` to move `wait` closer to its consumer (#277)

Open • kiya00 opened this pull request 1 year ago • 11 comments

Before submitting
  • [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • [ ] Did you read the contributor guideline, Pull Request section?
  • [ ] Did you make sure to update the docs?
  • [ ] Did you write any new necessary tests?

What does this PR do?

Fixes #277.

For more details, please see: https://github.com/Lightning-AI/lightning-thunder/issues/277#issuecomment-2100595522

cc @carmocca @awaelchli @crcrpar

Posted by kiya00 on May 08 '24 15:05

Unfortunately, this change doesn't improve performance on Llama-2-7b-hf.

Ran the benchmark with:

    torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=2 --global_batch_size=16 --model_name=Llama-2-7b-hf --return_metrics_as_json=True --json_path=benchmark_litgpt_datanew.json

Env: 8 x H100 80GB, nvfuser: 0.2.3+git729f36c

Before this commit (on 7cff363abb814dafc4d1752253f30763cdb0a6f3):

    "average_iter_time": 787.7525961026549,
    "model_flops": 377527625318400,
    "model_flop_per_sec": 3834022026881291.5,
    "tokens_per_sec": 83194.7289094999,
    "memory_used_GB": 65.789196288,
    "model_name": "Llama-2-7b-hf",
    "Num GPUS": 8,
    "Seq Len": 4096,
    "Micro BS": 2,
    "Global BS": 16,
    "GA": 1,
    "Distributed Mode": "fsdp_zero2_none_bucketing",
    "Sharding Size": null,
    "compiler": "thunder"

This PR:

    "average_iter_time": 793.2260634377599,
    "model_flops": 377527625318400,
    "model_flop_per_sec": 3809612930806172.5,
    "tokens_per_sec": 82665.07411965841,
    "memory_used_GB": 65.790769152,
    "model_name": "Llama-2-7b-hf",
    "Num GPUS": 8,
    "Seq Len": 4096,
    "Micro BS": 2,
    "Global BS": 16,
    "GA": 1,
    "Distributed Mode": "fsdp_zero2_none_bucketing",
    "Sharding Size": null,
    "compiler": "thunder"
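For scale, the relative change between the two runs can be computed directly from the reported metrics. This is only arithmetic on the numbers above; `relative_change` is an illustrative helper, not part of the benchmark script.

```python
# Metrics copied from the two benchmark runs above (before this commit vs. this PR).
before = {"average_iter_time": 787.7525961026549, "tokens_per_sec": 83194.7289094999}
after = {"average_iter_time": 793.2260634377599, "tokens_per_sec": 82665.07411965841}


def relative_change(old: float, new: float) -> float:
    """Return (new - old) / old, i.e. the fractional change from old to new."""
    return (new - old) / old


for key in before:
    pct = 100.0 * relative_change(before[key], after[key])
    print(f"{key}: {pct:+.2f}%")
```

Both metrics differ by less than 1% between the two runs: iteration time goes up by about 0.7% and throughput goes down by about 0.6%.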

cc: @IvanYashchuk

Posted by kiya00 on May 08 '24 15:05

> Unfortunately, this change doesn't improve performance on Llama-2-7b-hf.

Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling show about the overlap?

Posted by IvanYashchuk on May 13 '24 10:05

> Is the wait operation now inserted at what seems like the right place to allow computation-communication overlap? What does Nsight Systems profiling show about the overlap?

Before this commit: [profiler screenshot]

This PR: [profiler screenshot]

Although performance doesn't change, with this PR the allgathers in stream 22 overlap with the computation in stream 7.

@IvanYashchuk

Posted by kiya00 on May 14 '24 13:05

Hi @t-vi @carmocca, I think it's ready to merge.

Posted by kiya00 on May 15 '24 15:05

Just tried this branch, and while I do see some small overlap with the first few layers, the majority is still not overlapped, and the AllGathers are launched well ahead of compute.

[Screenshot 2024-05-15 at 5:31:49 PM]

Posted by parthmannan on May 16 '24 00:05

I think it's expected that all the allgathers are launched at the beginning. To reduce the overhead, the rate limiting we do for zero3 could be needed for zero2 as well, which https://github.com/Lightning-AI/lightning-thunder/pull/259 addresses by adding an argument.

Posted by crcrpar on May 16 '24 03:05

> Just tried this branch, and while I do see some small overlap with the first few layers, the majority is still not overlapped, and the AllGathers are launched well ahead of compute.

Are stream barriers visible in the profile? They would be a good indicator of whether we're doing the correct thing at the trace level.

Posted by IvanYashchuk on May 16 '24 09:05

> I think it's expected that all the allgathers are launched at the beginning. To reduce the overhead, the rate limiting we do for zero3 could be needed for zero2 as well, which #259 addresses by adding an argument.

For zero3, rate limiting is needed to limit peak allocated memory, because every allgather call allocates its unsharded output tensor. We need to limit the number of active allgathers until their results are consumed and freed.

For zero2, we don't care about memory consumption, because all the unsharded tensors are saved for backward. Are there other effects of limiting the number of allgathers besides peak memory allocation?

Posted by IvanYashchuk on May 16 '24 09:05

In zero2, the forward trace would consist of a sequence of all-gathers followed by a sequence of computations, which would explain the long idle period in the compute stream (at least on paper). To shorten the initial sequence of all-gathers, rate limiting should help, as it reorders the bsyms so that a few all-gathers are followed by some compute that uses those all-gathers' outputs.
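The reordering described above can be illustrated on a toy trace. This is a minimal sketch, not Thunder's implementation: the `(kind, parameter)` op tuples and the functions `rate_limited_schedule` and `peak_inflight` are hypothetical, and it assumes each compute op consumes exactly one gathered parameter, in the order given. With `max_inflight=None` every allgather is issued up front (the long initial run of all-gathers discussed above); with a cap, only a few allgathers are outstanding before the compute that uses them starts.

```python
from typing import List, Optional, Tuple

Op = Tuple[str, str]  # (kind, parameter name); kind is "allgather", "wait" or "compute"


def rate_limited_schedule(params: List[str], max_inflight: Optional[int] = None) -> List[Op]:
    """Interleave allgathers with the compute that consumes them.

    params: parameter names in the order their compute consumes them.
    max_inflight: maximum number of issued-but-not-yet-consumed allgathers;
        None means unlimited (all allgathers are issued first).
    """
    if max_inflight is not None and max_inflight < 1:
        raise ValueError("max_inflight must be at least 1")
    limit = len(params) if max_inflight is None else max_inflight
    schedule: List[Op] = []
    issued = 0  # number of allgathers issued so far
    for consumed, p in enumerate(params):
        # Issue allgathers ahead of time, keeping at most `limit` outstanding.
        while issued < len(params) and issued - consumed < limit:
            schedule.append(("allgather", params[issued]))
            issued += 1
        # Wait right before the consumer, then run the compute.
        schedule.append(("wait", p))
        schedule.append(("compute", p))
    return schedule


def peak_inflight(schedule: List[Op]) -> int:
    """Largest number of allgathers whose results have not yet been consumed."""
    live = peak = 0
    for kind, _ in schedule:
        if kind == "allgather":
            live += 1
            peak = max(peak, live)
        elif kind == "compute":
            live -= 1
    return peak


params = [f"w{i}" for i in range(6)]
print(peak_inflight(rate_limited_schedule(params)))                  # 6: all issued up front
print(peak_inflight(rate_limited_schedule(params, max_inflight=2)))  # 2: capped
```

In the zero3 case, each outstanding allgather holds an unsharded output buffer, so the cap bounds peak memory; in the zero2 case, the same interleaving shortens the initial run of all-gathers before compute can begin.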

Posted by crcrpar on May 16 '24 12:05

For #277, bisecting showed that https://github.com/Lightning-AI/lightning-thunder/commit/a76beb6328149b7799765a58ed38a892be39ca97 is the first bad commit. Comparing the traces before and after this commit, I found that the order of the allgathers changed: before the commit, the first allgather to be consumed was always the first to appear in the trace, and so on. So, as @t-vi suggested, I experimented with an early transform that reorders the parameters specifically for Llama-2-7b-hf, and with it the allgathers overlap better with the computation:

[profiler screenshot]

The next step will be to rearrange all the allgathers in the order of consumption.
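The property that the bisect uncovered (allgathers appear in the trace in the same order in which their results are first consumed) can be checked mechanically. A minimal sketch on a hypothetical toy trace model, not Thunder's trace or bound-symbol API: `Op` and `first_out_of_order_allgather` are illustrative names, and waits are omitted for brevity.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class Op:
    kind: str                # "allgather" or anything else (hypothetical toy kinds)
    inputs: Tuple[str, ...]  # names this op reads
    output: str              # name this op produces


def first_out_of_order_allgather(ops: List[Op]) -> Optional[Tuple[str, str]]:
    """Return (earlier, later) for the first pair of consecutive allgathers, in trace
    order, whose results are consumed in the opposite order; None if trace order
    already matches consumption order."""
    first_use = {}
    for idx, op in enumerate(ops):
        for name in op.inputs:
            first_use.setdefault(name, idx)
    gathers = [op.output for op in ops if op.kind == "allgather"]
    never = len(ops)  # results that are never read sort last
    for a, b in zip(gathers, gathers[1:]):
        if first_use.get(a, never) > first_use.get(b, never):
            return a, b
    return None


compute = [Op("compute", ("x", "w0"), "h"), Op("compute", ("h", "w1"), "y")]
good = [Op("allgather", ("s0",), "w0"), Op("allgather", ("s1",), "w1")] + compute
bad = [Op("allgather", ("s1",), "w1"), Op("allgather", ("s0",), "w0")] + compute
print(first_out_of_order_allgather(good))  # None
print(first_out_of_order_allgather(bad))   # ('w1', 'w0')
```

Checking consecutive pairs is enough, because a sequence is sorted exactly when every adjacent pair is in order.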

Patch specific to Llama-2-7b-hf:
diff --git a/thunder/__init__.py b/thunder/__init__.py
index adc6d692..a82d0d1d 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -301,7 +301,8 @@ def jit(
         interpreter = _general_frontend
 
     if early_transforms is None:
-        early_transforms = []
+        # early_transforms = []
+        early_transforms = [tmp_trans, ]
 
     if additional_transforms is None:
         additional_transforms = []
@@ -938,3 +939,36 @@ def grad(fn):
         return original_result, original_trace
 
     return _fn
+def tmp_trans(prologue_trc, computation_trc, epilogue_trc, executors_list):
+    old_order = ("tos1", "sin", "transformer_wte_weight", *(f"transformer_h_{i}_{p}_weight" for i in range(32) for p in ("norm_1", "attn_attn", "attn_proj", "norm_2", "mlp_fc_1", "mlp_fc_2", "mlp_proj")), "transformer_ln_f_weight", "lm_head_weight", "idx")
+    ret = prologue_trc.bound_symbols[-1]
+    assert(ret.sym.id == prims.PrimIDs.RETURN)
+
+    def sort_func(x):
+        if x.name in old_order:
+            return old_order.index(x.name)
+        assert(x.name[2:] in old_order)
+        return old_order.index(x.name[2:])
+    new_ret_args = tuple(sorted(ret.args[0], key=sort_func))
+    from dataclasses import replace
+    new_ret = replace(ret, args=(new_ret_args,))
+    new_ret = replace(new_ret, output=(new_ret_args,))
+    prologue_trc.bound_symbols[-1]=new_ret
+
+    new_args = tuple(sorted(computation_trc.args, key=sort_func))
+    computation_trc.args=new_args
+    new_list = []
+    old_list = computation_trc.bound_symbols[0:len(old_order)]
+
+    for n in old_order:
+        tmp = [p for p in old_list if p.output.name == n or p.output.name == "t_"+n]
+        assert(len(tmp) == 1)
+        new_list.append(tmp[0])
+    assert(len(new_list)==len(old_order))
+    computation_trc.bound_symbols[0:len(old_order)] = new_list
+
+    siginfo_args = []
+    for bsym in new_list:
+        siginfo_args.append((bsym.output.name, None))
+    computation_trc._siginfo.args = siginfo_args
+    return prologue_trc, computation_trc, epilogue_trc
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 4b553c2f..816c6f11 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -403,6 +403,11 @@ class Benchmark_litGPT:
             input_ids = input_ids.to(device=self.device)
             targets = targets.to(device=self.device)
             loss = run_fwd_bwd_one_microbatch(self.model, input_ids, targets, self.gradient_accumulation_steps)
+            # if i==0 and global_rank==0:
+            #     # with open('old_trace','w') as f:
+            #     with open('reordertrace','w') as f:
+            #         f.write(str(thunder.last_traces(self.model)[-1]))
+            #         f.write(str(thunder.last_backward_traces(self.model)[-1]))
 
             # Simple Gradient Accumulation Implementation
             self.optimizer.step()

Posted by kiya00 on May 22 '24 08:05

Hi @IvanYashchuk @crcrpar, using sort_wait_zero3 (which sorts each allgather+wait to just before its consumer) together with an unlimited number of in-flight allgathers (which pushes the allgathers to the beginning of the trace) fixes the problem. The allgathers are sorted in the order of their consumers and listed at the beginning of the trace, and the corresponding waits are placed right before the consumers. nsys results for Llama-2-7b-hf:

[profiler screenshot]
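The schedule described in this comment can be sketched on a toy trace model. This is not `sort_wait_zero3` itself or Thunder's trace API: the `Op` record and `gathers_first_waits_late` are hypothetical, and the sketch assumes every wait consumes one allgather future and produces the unsharded tensor. It emits all allgathers first, ordered by when their tensors are first used, and moves each wait to just before its first consumer, so communication can run ahead while compute blocks only when it actually needs a tensor.

```python
from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass(frozen=True)
class Op:
    kind: str                # "allgather", "wait", or anything else (treated as compute)
    inputs: Tuple[str, ...]  # names this op reads
    output: str              # name this op produces


def gathers_first_waits_late(ops: List[Op]) -> List[Op]:
    """All allgathers up front in consumption order; each wait just before its first consumer."""
    waits = {op.output: op for op in ops if op.kind == "wait"}         # tensor -> wait
    gathers = {op.output: op for op in ops if op.kind == "allgather"}  # future -> allgather

    front: List[Op] = []  # allgathers, in the order their tensors are first consumed
    body: List[Op] = []   # compute ops, with each wait inserted right before first use
    waited: Set[str] = set()
    for op in ops:
        if op.kind in ("allgather", "wait"):
            continue
        for name in op.inputs:
            wait = waits.get(name)
            if wait is not None and name not in waited:
                waited.add(name)
                front.extend(gathers[f] for f in wait.inputs if f in gathers)
                body.append(wait)
        body.append(op)

    # Keep allgathers/waits whose results are never consumed, after the used ones.
    scheduled = {g.output for g in front}
    leftover_gathers = [g for f, g in gathers.items() if f not in scheduled]
    leftover_waits = [w for t, w in waits.items() if t not in waited]
    return front + leftover_gathers + body + leftover_waits


# Each wait originally sits right after its allgather, which blocks overlap.
trace = [
    Op("allgather", ("w0_shard",), "f0"),
    Op("wait", ("f0",), "w0"),
    Op("allgather", ("w1_shard",), "f1"),
    Op("wait", ("f1",), "w1"),
    Op("compute", ("x", "w1"), "h"),
    Op("compute", ("h", "w0"), "y"),
]
print([f"{op.kind}:{op.output}" for op in gathers_first_waits_late(trace)])
# ['allgather:f1', 'allgather:f0', 'wait:w1', 'compute:h', 'wait:w0', 'compute:y']
```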

Posted by kiya00 on May 23 '24 16:05

Nice, so the main problem was that the order of the allgathers did not follow the order in which they were used in the compute part, forcing a long delay before the compute started.

Posted by IvanYashchuk on May 29 '24 07:05

Before merging, let's see what impact this latest iteration has on perf.

Command:

    torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --compile=thunder --distributed_mode=fsdp --micro_batch_size=2 --global_batch_size=16 --model_name=Llama-2-7b-hf

On main (b8705922c344a7d08f9ac43ac1b06d2ff7bbaf62):

Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 2
Global BS: 16
Number of Layers: 32
Number of parameters: 0.84B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Average iter time: 780.92 ms
Memory used: 65.79 GB
Tokens/s: 83912.17
Tokens/s/GPU: 10489.02
TFLOP/s: 3867.09

this PR:

Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 2
Global BS: 16
Number of Layers: 32
Number of parameters: 0.84B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Average iter time: 776.61 ms
Memory used: 65.79 GB
Tokens/s: 84361.17
Tokens/s/GPU: 10545.15
TFLOP/s: 3887.78
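A quick arithmetic check on the two runs above, using only the reported numbers: the reported Tokens/s/GPU appears to be the aggregate Tokens/s divided by the 8 GPUs, and the relative change of each metric can be computed directly. The dictionary keys are illustrative, not the benchmark's own field names.

```python
# Numbers copied from the two runs above (main vs. this PR), 8 GPUs each.
NUM_GPUS = 8
main = {"iter_ms": 780.92, "tokens_per_s": 83912.17, "tflops": 3867.09}
pr = {"iter_ms": 776.61, "tokens_per_s": 84361.17, "tflops": 3887.78}

for name, run in (("main", main), ("this PR", pr)):
    # Per-GPU throughput derived from the aggregate Tokens/s.
    print(name, "tokens/s/GPU:", f"{run['tokens_per_s'] / NUM_GPUS:.2f}")

for key in main:
    # Relative change from main to this PR, in percent.
    print(key, f"{100.0 * (pr[key] - main[key]) / main[key]:+.2f}%")
```

All three metrics move by roughly half a percent in favour of this PR.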

I think it's ready to merge @IvanYashchuk @t-vi

Posted by kiya00 on May 29 '24 09:05