Perf is not great on HF Transformers Llama 3.2 1B
🐛 Bug
Performance is not great on an H100, using the quickstart example.
Transformers full options: 482.82ms
Transformers overhead: 649.67ms
Transformers plain eager: 1197.29ms
Thunder: 1775.08ms
To reproduce
import torch
import transformers
import transformers.generation
import thunder
import thunder.recipes
import thunder.plugins
from thunder.dev_utils.benchmark import benchmark_n
model_name = "meta-llama/Llama-3.2-1B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
with torch.device(device):
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    model.requires_grad_(False)
    model.eval()
    # apparently, Transformers 4.51.3 does not instantiate models on the default device
    model.to(device)
inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt")
def generate(model, inp, transformers_compile='full'):
    genconf = model._prepare_generation_config(None)[0]
    if transformers_compile == 'none':  # disable torch.compile entirely
        genconf.disable_compile = True
    elif transformers_compile == 'overhead':
        # keep torch.compile, but with the "default" mode rather than "reduce-overhead"
        genconf.compile_config = transformers.generation.CompileConfig(mode="default")
    elif transformers_compile != 'full':
        raise NotImplementedError(f"unsupported {transformers_compile=}")
    out = model.generate(**inp, do_sample=False, generation_config=genconf, cache_implementation="static", max_new_tokens=100)
    print(tokenizer.decode(out[0].tolist()))
print("\nGenerating with PyTorch eager:")
transformers_full = benchmark_n(2, generate, model, inp, device=device)
transformers_overhead = benchmark_n(2, generate, model, inp, transformers_compile='overhead', device=device)
transformers_no_compile = benchmark_n(2, generate, model, inp, transformers_compile='none', device=device)
recipe = thunder.recipes.HFTransformers()
thunder_model = thunder.compile(
    model,
    recipe=recipe,
    # plugins=thunder.plugins.ReduceOverhead(),  # CUDAGraphs will produce garbage output on main.
)
print("\nGenerating with Thunder:")
thunder_time = benchmark_n(2, generate, thunder_model, inp, device=device)
print(f"Transformers full options: {transformers_full:.2f}ms")
print(f"Transformers overhead: {transformers_overhead:.2f}ms")
print(f"Transformers plain eager: {transformers_no_compile:.2f}ms")
print(f"Thunder: {thunder_time:.2f}ms")
To get a first look at what's going on:
with torch.profiler.profile(with_stack=True) as prof:
    out = thunder_model.generate(**inp, do_sample=False, cache_implementation="static", max_new_tokens=5)
prof.export_chrome_trace('thunder.json')
print(prof.key_averages().table(sort_by="self_device_time_total"))
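To relate the profile to what Thunder actually executed, the computation traces can also be inspected directly. A minimal sketch, assuming thunder.last_traces accepts the compiled module and returns its traces with the execution trace last (run at least one generation first so traces exist):

# Sketch (assumption): dump Thunder's execution trace for the compiled module.
# thunder.last_traces is assumed to return the list of computation traces,
# with the final, executable trace as the last entry.
traces = thunder.last_traces(thunder_model)
execution_trace = traces[-1]
print(execution_trace)  # nvFuser regions show up as nvFusion0, nvFusion1, ...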
Expected behavior
Thunder should be within ~20% of the Transformers "overhead" configuration (no CUDA graphs). Ideally, the CUDAGraphs-based path would also be within ~20% and produce correct generation output.
Environment
PyTorch 2.8.0 with CUDA 12.8, nvfuser-cu128-torch2.8, H100
As an update: replacing the thunder_model construction with
recipe = thunder.recipes.HFTransformers()
recipe.executor_names = [
    'nvfuser',
    'inplace_index_copy_ex',
    'sdpa_mask_transform_ex',
]
thunder_model = thunder.compile(
    model,
    recipe=recipe,
    # plugins=thunder.plugins.ReduceOverhead(),  # CUDAGraphs will produce garbage output on main.
)
i.e. removing the cudnn and sdpa executors, we get:
Transformers full options: 479.58ms
Transformers overhead: 674.12ms
Transformers plain eager: 1400.24ms
Thunder: 1271.31ms
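For reference, the registered executors and their exact names (used by recipe.executor_names above) can be listed; a small sketch, assuming thunder.extend.get_all_executors() is available and each executor exposes a .name attribute:

# Sketch (assumption): enumerate the executors Thunder knows about,
# to double-check names such as 'nvfuser', 'cudnn', and 'sdpa'.
from thunder.extend import get_all_executors

for ex in get_all_executors():
    print(ex.name)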
I see a few nvFuser regions in the execution trace that contain only a single copy operation:
[t3251] = nvFusion44(t5356, t5355)
  # t3251 = prims.copy_(t5356, t5355, grad_enabled=False)  # t3251: "cuda:0 bf16[1, 8, 108, 64]"
Maybe we should avoid such regions.
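To check how many such single-copy regions end up in the execution trace, the trace can be scanned programmatically. A rough sketch, assuming the execution trace comes from thunder.last_traces(thunder_model)[-1], fusion symbols are named nvFusion<N>, and their fused ops are stored in subsymbols:

# Sketch (assumptions as above): count nvFuser regions whose only op is a copy_.
execution_trace = thunder.last_traces(thunder_model)[-1]
copy_only = [
    bsym
    for bsym in execution_trace.bound_symbols
    if bsym.sym.name.startswith("nvFusion")
    and len(bsym.subsymbols) == 1
    and "copy_" in bsym.subsymbols[0].sym.name
]
print(f"{len(copy_only)} nvFusion region(s) contain only a copy_")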
Here is a patch that keeps the nvFuser executor from creating a region containing only a copy_:
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 2af424cc..c0047281 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -884,6 +884,11 @@ class nvFuserExecutor(FusionExecutor):
             # if len(bsyms) > 1:
             region = Region(producers, consumers, bsyms)
+            if len(bsyms) == 1:
+                if bsyms[0].sym.id == PrimIDs.COPY_:
+                    fused_bsyms.extend(bsyms)
+                    continue
+
             nv_enable_shape_only_fusion: None | bool = get_compile_option(
                 "nv_enable_shape_only_fusion",
                 "Allow nvFuser to create Fusion with shape only operations. Defaults to False.",
Numbers:
Transformers full options: 371.01ms
Transformers overhead: 392.79ms
Transformers plain eager: 647.86ms
Thunder (without patch): 542.48ms
Thunder (with patch): 507.02ms