TriForce
Out of memory on H800
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
--chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00, 3.18it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 124928
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################
[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 144, in initialize_cuda_graph
self.callables[gamma_offset] = draft_run_capture_graph(
File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 83, in draft_run_capture_graph
static_logits = engine.draft_run(input_ids=static_input_ids, gamma_offset=gamma_offset, probs=probs, temperature=temperature, top_p=top_p)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 54, in draft_run
logits = self.draft(input_ids=input_ids, kv_cache=self.draft_cache, graph_cache=self.draft_cache, gamma_offset=gamma_offset).logits
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 340, in forward
outputs = self.model(
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 301, in forward
layer_outputs = decoder_layer(
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 220, in forward
hidden_states = self.self_attn(
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 141, in forward
query_states = self.q_proj(hidden_states)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU
Hello, may I ask how much memory your device has? You can try decreasing the prefill length from 124928 to 122880 to see if it still OOMs. The code can run on a single A100-80GB.
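If it helps, you can check the total device memory with nvidia-smi, or from Python with a quick one-liner like this (just a generic check, not part of the repo):
python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / 1024**3, 'GiB')"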
With prefill 122880 I got this error:
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.85s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00, 2.82it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################
[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 154, in initialize_cuda_graph
self.callable_model_verify = model_verify_capture_graph(
File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 111, in model_verify_capture_graph
static_logits = engine.model_verify(input_ids=static_input_ids, position_ids=static_position_ids, probs=probs, temperature=temperature, top_p=top_p)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 63, in model_verify
logits = self.model(input_ids=input_ids, kv_cache=self.kv_cache, graph_cache=self.graph_cache, position_ids=position_ids, spec=True).logits
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 396, in forward
outputs = self.model(
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 355, in forward
layer_outputs = decoder_layer(
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 276, in forward
hidden_states = self.self_attn(
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 222, in forward
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 182, in apply_rotary_pos_emb
q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (32) must match the size of tensor b (131072) at non-singleton dimension 1
nvidia-smi
Wed Jun 26 10:40:05 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H800 On | 00000000:61:00.0 Off | 0 |
| N/A 30C P0 70W / 700W | 4MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
What is your transformers version? Can you set it to transformers==4.37.2? The apply_rotary_pos_emb API changed in recent versions.
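In case it helps, you can check and pin the version like this (assuming a pip-managed environment):
pip show transformers
pip install transformers==4.37.2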
Thanks, I think it was more likely an environment problem. It works now:
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.65s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00, 3.15it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################
[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
[model verify] capturing graph for spec len 6 (probs=True, temp=0.6, top_p=0.9)...
[Full Cache] Cached: 0 | Budget: 123152
[Retrieval Cache] Budget: 4096 | PreFill: 122880 | Chunk Size: 8 | Chunks: 15360 | Select Sets: 512
[StreamingLLM Cache] Start Size: 16 | Recent Size: 234 | Gamma: 6 | Real Budget: 259 | Cached: 0
tokenized_prompts length: 20
Autoregressive Warmup: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.44s/it]
Autoregressive Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:35<00:00, 35.70s/it]
[Autoregressive] average latency: 31.812156550586224 ms
TriForce Warmup: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:44<00:00, 34.95s/it]
TriForce Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [11:08<00:00, 33.43s/it]
average acceptance rate (NOT per token): 0.6893210335795382
[TriForce] average latency: 17.238558956780214 ms
[E2E Speedup]: 1.8454069525384529
I had changed the transformers version because of other environment issues, such as:
- torch installs against the default CUDA version (12.1), but my system has CUDA 11.8, so I installed a matching build manually.
- The latest flash-attn has this problem: https://github.com/Dao-AILab/flash-attention/issues/451, so I switched to an older version (2.5.2).
If possible, could you please add a note in the README about which flash_attn version you use and how to install the correct version of torch?
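For reference, this is roughly the kind of install I ended up with for my CUDA 11.8 setup (the exact torch version is up to you; the lines below are just an example of what worked for me, not official instructions):
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install flash-attn==2.5.2 --no-build-isolation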
Yeah, I am using CUDA 12.1. Here are my torch and flash_attn versions:
>>> import torch
>>> torch.__version__
'2.2.1+cu121'
>>> import flash_attn
>>> flash_attn.__version__
'2.5.7'
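If you want to match this environment, something along these lines should work (assuming prebuilt wheels exist for your Python/CUDA combination):
pip install torch==2.2.1 --index-url https://download.pytorch.org/whl/cu121
pip install flash-attn==2.5.7 --no-build-isolation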
I have added a FAQ in the README :)