lightning-thunder
Double first iteration time for `Llama-2-7b-hf` after nvFuser direct bindings
🐛 Bug
As per the title, the first iteration of the following command takes twice as long at commit 8b542cf16e12cf04f6375691621bb3adb0c4acea:
python thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-7b-hf --compile thunder_inductor_cat --checkpoint_activations False --low_precision_mode none --micro_batch_size 1 --global_batch_size 64 --use_sdpa False --block_size 4096 --max_iters 10 --warmup_iters 5
compared to before #2502:
# Before
iter 0: loss 0.1650, iter time: 105627.25ms, t: 4096
# After
iter 0: loss 0.1650, iter time: 215099.65ms, t: 4096
Tested on the latest container on B200
cc @rdspring1 @kshitij12345 @crcrpar
Direct bindings don't use an LRU cache for FusionDefinition creation. One was proposed in "[RFC] Create LRU cache for direct bindings" (#4893) but shelved until needed. From Thunder's perspective, you simply wrap create_fd in thunder/executors/nvfuserex_impl.py with LruFusionCache.
Code diff to add LruFusionCache
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index b65d6944..5531cd9d 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -77,6 +77,7 @@ if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
     from nvfuser_direct import (
         DataType,
         FusionDefinition,
+        LruFusionCache,
         multidevice,
         ParallelType,
         execute_with_dtensors,
@@ -297,6 +298,15 @@ def multidevice_schedule(fd: FusionDefinition, in_dtensors: list[Proxy]) -> None
in_tv.set_allocation_domain(in_tv.get_loop_domain(), new_contiguity=True)
+# This function wraps nvfuser_direct's LruFusionCache with a version check.
+def FusionCacheDecorator(func: callable):
+    # For legacy bindings, the decorator does nothing.
+    if nvfuser_version() < DIRECT_BINDINGS_SUPPORTED_VERSION:
+        return func
+    from nvfuser_direct import LruFusionCache
+    return LruFusionCache(max_fusions=16384)(func)
+
+@FusionCacheDecorator
 def create_fd(
     bsyms: list[BoundSymbol],
     input_descriptors: Sequence[type | tuple[tuple[int, ...], tuple[bool, ...], tuple[int, ...]]],
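To illustrate why wrapping create_fd in a cache helps, here is a minimal sketch that uses Python's standard `functools.lru_cache` as a stand-in for `LruFusionCache`, whose internals are not shown in this thread. The `build_fusion` function, its tuple keys, and the three fusion labels are hypothetical; the only point is that repeated requests for a structurally identical definition are served from the cache instead of being rebuilt.

```python
from functools import lru_cache

build_calls = 0  # counts how often the expensive path actually runs


@lru_cache(maxsize=16384)  # mirrors max_fusions=16384 from the patch
def build_fusion(definition_key: tuple) -> dict:
    """Hypothetical stand-in for an expensive FusionDefinition build."""
    global build_calls
    build_calls += 1
    return {"ops": definition_key}


# Assumption: 32 transformer layers that each request the same three fusions.
layer_fusions = [("rmsnorm", "bf16"), ("rope", "bf16"), ("swiglu", "bf16")]
for _layer in range(32):
    for key in layer_fusions:
        build_fusion(key)

info = build_fusion.cache_info()
print(f"builds={build_calls} hits={info.hits} misses={info.misses}")
```

With 96 requests and only three distinct keys, the expensive function body runs three times; the remaining 93 requests are cache hits.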
Timing for direct bindings with LRU cache (first iteration roughly 6-7% faster than legacy bindings):
Time to instantiate model: 0.10 seconds.
iter 0: loss 0.1650, iter time: 111987.17ms, t: 4096
iter 1: loss 0.1226, iter time: 14075.67ms, t: 4096
iter 2: loss 0.0771, iter time: 13722.17ms, t: 4096
iter 3: loss 0.0752, iter time: 13571.74ms, t: 4096
iter 4: loss 0.0747, iter time: 13525.45ms, t: 4096
iter 5: loss 0.0732, iter time: 13525.61ms, t: 4096
iter 6: loss 0.0732, iter time: 13499.64ms, t: 4096
iter 7: loss 0.0728, iter time: 13510.84ms, t: 4096
iter 8: loss 0.0728, iter time: 13483.25ms, t: 4096
iter 9: loss 0.0728, iter time: 13469.61ms, t: 4096
Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 1
Global BS: 64
Number of Layers: 32
Number of parameters: 6.74B
Distributed Mode: none
Compiler: thunder_inductor_cat
Low Precision Mode: none using low precision is not enabled
Average iter time: 13497.84 ms
Memory used: 75.52 GB
Saved for backward size: 30157.55 MiB
Saved for backward number of tensors: 712
Tokens/s: 19431.17
Tokens/s/GPU: 19431.17
TFLOP/s: 13.99
Timing for direct bindings without LRU cache (first iteration about 1.6x slower than legacy bindings in this run, 194.8 s vs. 119.4 s):
iter 0: loss 0.1650, iter time: 194757.66ms, t: 4096
iter 1: loss 0.1226, iter time: 14117.95ms, t: 4096
iter 2: loss 0.0771, iter time: 13743.62ms, t: 4096
iter 3: loss 0.0752, iter time: 13621.90ms, t: 4096
iter 4: loss 0.0747, iter time: 13587.59ms, t: 4096
iter 5: loss 0.0732, iter time: 13541.79ms, t: 4096
iter 6: loss 0.0732, iter time: 13512.29ms, t: 4096
iter 7: loss 0.0728, iter time: 13541.85ms, t: 4096
iter 8: loss 0.0728, iter time: 13533.82ms, t: 4096
iter 9: loss 0.0728, iter time: 13468.19ms, t: 4096
Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 1
Global BS: 64
Number of Layers: 32
Number of parameters: 6.74B
Distributed Mode: none
Compiler: thunder_inductor_cat
Low Precision Mode: none using low precision is not enabled
Average iter time: 13519.64 ms
Memory used: 75.52 GB
Saved for backward size: 30157.55 MiB
Saved for backward number of tensors: 712
Tokens/s: 19397.81
Tokens/s/GPU: 19397.81
TFLOP/s: 13.97
Timing for legacy bindings:
iter 0: loss 0.1650, iter time: 119407.35ms, t: 4096
iter 1: loss 0.1226, iter time: 14073.24ms, t: 4096
iter 2: loss 0.0771, iter time: 13733.47ms, t: 4096
iter 3: loss 0.0752, iter time: 13638.33ms, t: 4096
iter 4: loss 0.0747, iter time: 13588.40ms, t: 4096
iter 5: loss 0.0732, iter time: 13531.92ms, t: 4096
iter 6: loss 0.0732, iter time: 13513.27ms, t: 4096
iter 7: loss 0.0728, iter time: 13531.05ms, t: 4096
iter 8: loss 0.0728, iter time: 13533.12ms, t: 4096
iter 9: loss 0.0728, iter time: 13509.20ms, t: 4096
Model name: Llama-2-7b-hf
Seq Length: 4096
Micro BS: 1
Global BS: 64
Number of Layers: 32
Number of parameters: 6.74B
Distributed Mode: none
Compiler: thunder_inductor_cat
Low Precision Mode: none
Average iter time: 13523.76 ms
Memory used: 75.52 GB
Saved for backward size: 30157.55 MiB
Saved for backward number of tensors: 712
Tokens/s: 19386.89
Tokens/s/GPU: 19386.89
TFLOP/s: 13.96
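The first-iteration ratios implied by the three logs above can be checked with a few lines of Python. The numbers are copied from the `iter 0` lines; the dictionary keys are labels chosen only for this illustration.

```python
# iter 0 times in milliseconds, copied from the three logs above
first_iter_ms = {
    "direct + LRU cache": 111987.17,
    "direct, no cache": 194757.66,
    "legacy": 119407.35,
}

legacy = first_iter_ms["legacy"]
ratios = {name: ms / legacy for name, ms in first_iter_ms.items()}
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x the legacy first-iteration time")
```

On these runs the uncached direct bindings come out at about 1.63x the legacy first-iteration time and the cached ones at about 0.94x, while the average steady-state iteration times (13497.84 ms, 13519.64 ms, 13523.76 ms) are essentially identical.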
Is there a cheaper solution?
https://github.com/NVIDIA/Fuser/pull/4893 is a lot of code and maintenance.
Even with it, the first iteration taking 105 seconds is still painful.
Can Thunder avoid compiling the same trace multiple times?
I suspect the real issue might be similar to benchmark_inference.py lacking symbolic values.
> From Thunder's perspective, you simply wrap create_fd in thunder/executors/nvfuserex_impl.py with LruFusionCache.
The above patch seems good to me. I was just wondering if we still need the lru_cache in the snippet below:
https://github.com/Lightning-AI/lightning-thunder/blob/f8648aa98d4262d9988527973960ddc74a628214/thunder/executors/nvfuserex_impl.py#L618-L624
> Can Thunder avoid compiling the same trace multiple times? I suspect the real issue might be similar to benchmark_inference.py lacking symbolic values.
benchmark_litgpt uses litgpt models, which are friendly to Thunder, so we are able to capture the full graph. Also, I checked whether we were recompiling the model due to changing input metadata by adding this patch:
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 02262d15..ac2978cb 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -855,6 +855,8 @@ class Benchmark_litGPT:
print(
f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
)
+ print("cache hits:", thunder.cache_hits(self.model))
+ print("cache misses:", thunder.cache_misses(self.model))
if i >= self.warmup_iters:
if self.throughput:
self.throughput.update(
Across iterations it consistently prints `cache misses: 1`, corresponding to the first compilation (see code below).
https://github.com/Lightning-AI/lightning-thunder/blob/f8648aa98d4262d9988527973960ddc74a628214/thunder/__init__.py#L768-L774
So, it seems to me that the first iteration time comes mainly from the time it takes to compile a single trace of the whole model once.
However, I agree that 100 seconds is still painful. For reference, torch.compile with default mode takes around 45 seconds.
> the first iteration time comes mainly from the time it takes to compile a single trace of the whole model once.
Great to know that! So it's likely a single Thunder trace that contains many identical nvFusions. I guess these nvFusions come from an unrolled loop over the transformer layers? This is indeed different from https://github.com/Lightning-AI/lightning-thunder/issues/2687.
Yes, having a fusion cache in nvFuser makes sense.
Can you show me the Thunder trace? I'm worried that this solution is not sufficient for the longer term. When nvFuser accepts more ops and is able to fuse across transformer layers, it'll see an unrolled nvFusion that still takes >200 seconds to compile. How can we solve that?
About a year ago I looked a bit at avoiding recompilation by identifying identical subgraphs in the context of CUDA graphs:
- if the bsym order is stable, renaming the proxies as in https://github.com/Lightning-AI/lightning-thunder/blob/f8648aa98d4262d9988527973960ddc74a628214/thunder/core/transform_common.py#L449 will help.
- it is not clear to me if the bsym order is stable when you do the "dataflow" fusion thing.
The latter might also lead to "non-trivially isomorphic" fusions, not sure if that is what you're hitting.
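The renaming idea in the first bullet can be sketched as follows. This is not Thunder's implementation: `Op` is a hypothetical, simplified stand-in for a bound symbol, and `canonical_key` is an illustrative helper that renames proxies by order of first appearance so that two regions with the same structure hash to the same key. The last example shows the caveat from the second bullet: if independent ops are emitted in a different order, this naive key no longer matches.

```python
from typing import NamedTuple


class Op(NamedTuple):
    """Hypothetical, simplified stand-in for a bound symbol."""
    name: str
    inputs: tuple[str, ...]
    outputs: tuple[str, ...]


def canonical_key(ops: list[Op]) -> tuple:
    """Rename proxies by order of first appearance; return a hashable key."""
    rename: dict[str, str] = {}

    def canon(proxy: str) -> str:
        if proxy not in rename:
            rename[proxy] = f"v{len(rename)}"
        return rename[proxy]

    return tuple(
        (op.name, tuple(canon(p) for p in op.inputs), tuple(canon(p) for p in op.outputs))
        for op in ops
    )


# Two layers with the same structure but different proxy names.
layer0 = [Op("mul", ("t3", "t4"), ("t5",)), Op("add", ("t5", "t3"), ("t6",))]
layer1 = [Op("mul", ("t90", "t91"), ("t92",)), Op("add", ("t92", "t90"), ("t93",))]
same_structure = canonical_key(layer0) == canonical_key(layer1)

# Same dataflow, but the two independent ops (neg, exp) are swapped.
order_a = [Op("neg", ("a",), ("b",)), Op("exp", ("a",), ("c",)), Op("add", ("b", "c"), ("d",))]
order_b = [Op("exp", ("a",), ("c",)), Op("neg", ("a",), ("b",)), Op("add", ("b", "c"), ("d",))]
order_sensitive = canonical_key(order_a) != canonical_key(order_b)

print(same_structure, order_sensitive)
```

Handling the reordered case would need a canonical ordering of the ops (or a graph-isomorphism check), which this sketch deliberately leaves out.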
Thanks for the context!
IIUC, we could cache FusionDefinition in Thunder as well. https://github.com/Lightning-AI/lightning-thunder/blob/f8648aa98d4262d9988527973960ddc74a628214/thunder/executors/nvfuserex_impl.py#L622 caches FusionDefinition for the same input descriptors passed into FusionDefinitionWrapper. However, FusionDefinitionWrapper itself can be cached as well.
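A rough sketch of that two-level idea, under stated assumptions: `Wrapper` is a hypothetical stand-in for FusionDefinitionWrapper, `get_wrapper` is an invented helper, and the trace key and input descriptors are made-up values shaped like the `input_descriptors` type in the diff above. The outer cache returns one wrapper per structurally distinct region; the inner cache returns one definition per distinct input descriptor.

```python
from functools import lru_cache


class Wrapper:
    """Hypothetical stand-in for FusionDefinitionWrapper."""
    instances = 0  # how many wrappers were actually constructed

    def __init__(self, trace_key: tuple):
        Wrapper.instances += 1
        self.trace_key = trace_key
        # Inner level: one definition per distinct input descriptor.
        self.get_fd = lru_cache(maxsize=None)(self._build_fd)

    def _build_fd(self, input_descriptors: tuple) -> tuple:
        # Placeholder for the expensive definition build.
        return (self.trace_key, input_descriptors)


@lru_cache(maxsize=None)
def get_wrapper(trace_key: tuple) -> Wrapper:
    """Outer level: one wrapper per structurally distinct fusion region."""
    return Wrapper(trace_key)


# Illustrative key and descriptors (sizes, contiguity, stride order).
key = (("mul", "v0", "v1"), ("add", "v2", "v0"))
descriptors = ((1, 4096, 4096), (True, True, True), (2, 1, 0))

# Assumption: 32 layers request the same region with the same inputs.
fds = [get_wrapper(key).get_fd(descriptors) for _layer in range(32)]
inner_misses = get_wrapper(key).get_fd.cache_info().misses
print(Wrapper.instances, inner_misses)
```

In this toy setup, 32 requests construct a single wrapper and build a single definition. Whether the real wrapper can be keyed this way depends on having a stable, hashable description of the fusion region.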