Skipping cudagraphs for unknown reason
System Info
- `transformers` version: 4.41.2
- Platform: Linux-5.15.0-112-generic-x86_64-with-glibc2.35
- Python version: 3.10.13
- Huggingface_hub version: 0.23.4
- Safetensors version: 0.4.2
- Accelerate version: 0.28.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
@ArthurZucker
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
I read issues #30055 and #30351, and Llama works well with `cache_implementation="static"`. However, I am trying to use `torch.compile` for other models such as `pythia` and `phi-2`, where `cache_implementation="static"` is not applicable, and it produces errors like:
```
[2024-06-26 18:18:57,065] [0/0] torch._inductor.fx_passes.split_cat: [WARNING] example value absent for node: cat_3
[2024-06-26 18:18:57,065] [0/0] torch._inductor.fx_passes.split_cat: [WARNING] example value absent for node: cat_2
[2024-06-26 18:18:57,065] [0/0] torch._inductor.fx_passes.split_cat: [WARNING] example value absent for node: cat_1
[2024-06-26 18:18:57,065] [0/0] torch._inductor.fx_passes.split_cat: [WARNING] example value absent for node: cat
skipping cudagraphs for unknown reason
...
...
  File "/tmp/torchinductor_hc29225/bi/cbiig6bqpidhtncuswvfxwqqjwoiiswlwlrnh7eobbwm4wjlvpts.py", line 15465, in call
    extern_kernels.addmm(arg4_1, reinterpret_tensor(buf3, (16, 2560), (2560, 1), 0), reinterpret_tensor(arg3_1, (2560, 2560), (1, 2560), 0), alpha=1, beta=1, out=buf4)
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
```
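For contrast, here is a minimal sketch of the configuration that does compile cleanly for me, mirroring the commented-out Llama-2 lines in the reproduction below (the checkpoint is gated, so `access_token` is a placeholder for my own Hugging Face token):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

access_token = "hf_..."  # placeholder: my own token for the gated Llama-2 checkpoint

# Llama-2 supports cache_implementation="static", so torch.compile works here.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
    token=access_token,
).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=access_token)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(
    ["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt"
).to(model.device)

out = model.generate(
    **inputs,
    do_sample=False,
    cache_implementation="static",  # supported by Llama-2
    max_new_tokens=64,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```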
Here is my code for reproducing the errors.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Llama works well with cache_implementation="static", but other types of models do not have the configuration.
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="sdpa", token=access_token).cuda().eval()
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=access_token)

# model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-2.8b", torch_dtype=torch.float16, trust_remote_code=True).cuda().eval()
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b", trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
max_new_tokens = 64

fn = lambda: model.generate(
    **inputs,
    do_sample=False,
    # cache_implementation="static",  # this only works for Llama-2
    max_new_tokens=max_new_tokens,
    pad_token_id=tokenizer.pad_token_id,
    temperature=None,
    top_p=None,
)
fn()
```
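As a side note on how I measure (not part of the failure itself): with `mode="reduce-overhead"`, the first calls include compilation and cudagraph capture (when it is not skipped), so I warm the model up before looking at latency, roughly like this:

```python
# Warm-up: the first calls of a reduce-overhead compiled model pay the
# compilation/capture cost and are not representative of steady-state latency.
for _ in range(3):
    fn()
```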
Expected behavior
Models such as `pythia` and `phi-2` can run with `torch.compile`, and a clear latency improvement can be observed.
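For reference, the latency comparison I have in mind is a simple wall-clock measurement around the compiled `generate` call (a minimal sketch, assuming the model sits on a CUDA device so that `torch.cuda.synchronize()` applies; `fn` is the call defined in the reproduction above):

```python
import time

import torch

torch.cuda.synchronize()
start = time.perf_counter()
fn()  # the compiled generate call from the reproduction
torch.cuda.synchronize()
print(f"generate latency: {time.perf_counter() - start:.3f} s")
```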