Out of memory on H800

Open Lucas-TY opened this issue 1 year ago • 8 comments

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.18it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 124928
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 144, in initialize_cuda_graph
    self.callables[gamma_offset] = draft_run_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 83, in draft_run_capture_graph
    static_logits = engine.draft_run(input_ids=static_input_ids, gamma_offset=gamma_offset, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 54, in draft_run
    logits = self.draft(input_ids=input_ids, kv_cache=self.draft_cache, graph_cache=self.draft_cache, gamma_offset=gamma_offset).logits
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 340, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 301, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 220, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 141, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 

Lucas-TY avatar Jun 25 '24 21:06 Lucas-TY

Hello, how much memory does your device have? You can try decreasing the prefill from 124928 to 122880 to see whether it still runs out of memory. The code can run on a single A100-80GB.
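For a sense of how little headroom an 80 GB card leaves at these prefill lengths, here is a rough back-of-the-envelope estimate. It is only a sketch: it assumes the standard Llama-2-7B shape (32 layers, hidden size 4096, no grouped-query attention), fp16 storage for both weights and KV cache, and an approximate parameter count of 6.74e9. None of these values come from the thread itself, and the function names are made up for illustration.

```python
# Rough memory estimate for a full fp16 KV cache plus fp16 weights.
# All constants below are assumptions (standard Llama-2-7B shape, fp16).
NUM_LAYERS = 32
HIDDEN_SIZE = 4096      # 32 heads x 128 head dim, no grouped-query attention
BYTES_PER_VALUE = 2     # fp16
NUM_PARAMS = 6.74e9     # approximate parameter count of a 7B Llama-2 model


def kv_cache_gib(num_tokens: int) -> float:
    """Size in GiB of the keys and values for num_tokens across all layers."""
    bytes_per_token = 2 * NUM_LAYERS * HIDDEN_SIZE * BYTES_PER_VALUE  # K and V
    return num_tokens * bytes_per_token / 2**30


def weights_gib() -> float:
    """Size in GiB of the model weights stored in fp16."""
    return NUM_PARAMS * BYTES_PER_VALUE / 2**30


if __name__ == "__main__":
    for prefill in (124928, 122880):
        cache = kv_cache_gib(prefill)
        total = cache + weights_gib()
        print(f"prefill={prefill}: KV cache {cache:.1f} GiB, "
              f"with weights {total:.1f} GiB")
```

Under these assumptions the KV cache alone is about 61 GiB at 124928 tokens and 60 GiB at 122880 tokens, before counting activations, the draft model, the retrieval cache, or CUDA graph buffers, so small differences in environment overhead can plausibly decide whether the run fits.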

preminstrel avatar Jun 25 '24 21:06 preminstrel

With 122880 I got this error:

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.85s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.82it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 154, in initialize_cuda_graph
    self.callable_model_verify = model_verify_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 111, in model_verify_capture_graph
    static_logits = engine.model_verify(input_ids=static_input_ids, position_ids=static_position_ids, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 63, in model_verify
    logits = self.model(input_ids=input_ids, kv_cache=self.kv_cache, graph_cache=self.graph_cache, position_ids=position_ids, spec=True).logits
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 396, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 355, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 276, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 222, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 182, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (32) must match the size of tensor b (131072) at non-singleton dimension 1

Lucas-TY avatar Jun 26 '24 02:06 Lucas-TY

nvidia-smi
Wed Jun 26 10:40:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H800                    On  | 00000000:61:00.0 Off |                    0 |
| N/A   30C    P0              70W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

Lucas-TY avatar Jun 26 '24 02:06 Lucas-TY

What is your transformers version? Could you pin it to transformers==4.37.2? The apply_rotary_pos_emb API changed in more recent versions.
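A quick way to confirm which transformers version is installed before rerunning is sketched below. The helper names (`version_tuple`, `matches_pin`) are hypothetical and not part of TriForce or transformers; the only library call used is `importlib.metadata.version` from the standard library.

```python
from importlib import metadata

PINNED = "4.37.2"  # version suggested in this thread


def version_tuple(version: str) -> tuple:
    """Turn '4.37.2' or '4.38.0.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component such as 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def matches_pin(installed: str, pinned: str = PINNED) -> bool:
    """True when the installed major.minor.patch equals the pinned one."""
    return version_tuple(installed)[:3] == version_tuple(pinned)


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed in this environment")
    else:
        status = "OK" if matches_pin(installed) else f"expected {PINNED}"
        print(f"transformers {installed}: {status}")
```

If the check reports a mismatch, `pip install transformers==4.37.2` inside the active environment installs the pinned version.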

preminstrel avatar Jun 26 '24 02:06 preminstrel

Thanks, I think it's more likely to be an environment problem.

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.15it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
[model verify] capturing graph for spec len 6 (probs=True, temp=0.6, top_p=0.9)...
[Full Cache] Cached: 0 | Budget: 123152
[Retrieval Cache] Budget: 4096  | PreFill: 122880  | Chunk Size: 8  | Chunks: 15360  | Select Sets: 512
[StreamingLLM Cache] Start Size: 16 | Recent Size: 234 | Gamma: 6 | Real Budget: 259 | Cached: 0
tokenized_prompts length: 20
Autoregressive Warmup: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.44s/it]
Autoregressive Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:35<00:00, 35.70s/it]
[Autoregressive] average latency: 31.812156550586224 ms
TriForce Warmup: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:44<00:00, 34.95s/it]
TriForce Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [11:08<00:00, 33.43s/it]
average acceptance rate (NOT per token): 0.6893210335795382
[TriForce] average latency: 17.238558956780214 ms
[E2E Speedup]: 1.8454069525384529
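The bookkeeping figures in the log above can be re-derived from the command-line arguments and the two reported latencies. The sketch below assumes that Chunks is the prefill length divided by the chunk size and that Select Sets is the budget divided by the chunk size; both match the printed numbers but are not checked against the TriForce source. The speedup is taken as the ratio of the two average latencies.

```python
# Values copied from the command line and the log output above.
prefill = 122880
budget = 4096
chunk_size = 8
autoregressive_ms = 31.812156550586224
triforce_ms = 17.238558956780214

# Assumed relationships (consistent with the log, not verified in the source).
chunks = prefill // chunk_size       # number of KV chunks in the prefill
select_sets = budget // chunk_size   # chunks that fit in the retrieval budget

# End-to-end speedup as the ratio of average per-token latencies.
speedup = autoregressive_ms / triforce_ms

print(f"chunks={chunks}, select_sets={select_sets}, speedup={speedup:.4f}")
```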

Lucas-TY avatar Jun 26 '24 04:06 Lucas-TY

I changed the transformers version because there were other environment issues, such as:

  • Torch uses the default CUDA version (12.1), but mine is 11.8 (I installed it manually).
  • The latest flash-attn has this problem: https://github.com/Dao-AILab/flash-attention/issues/451, so I switched to an older version (2.5.2).

If possible, could you please add a note in the README about the version of your flash_attn and how to install the correct version of torch?

Lucas-TY avatar Jun 26 '24 04:06 Lucas-TY

Yeah, I am using CUDA 12.1. Here are my torch and flash_attn versions.

>>> import torch
>>> torch.__version__
'2.2.1+cu121'
>>> import flash_attn
>>> flash_attn.__version__
'2.5.7'
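To compare environments more easily, the same information can be gathered in one go. This is a minimal sketch with hypothetical helper names (`collect_versions`, `torch_cuda_build`); it assumes the packages are registered under the distribution names `torch`, `flash-attn`, and `transformers`, and it reads `torch.version.cuda` to show which CUDA version the installed torch build targets.

```python
from importlib import metadata


def collect_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


def torch_cuda_build():
    """Return the CUDA version torch was built against, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them if pip lists them differently.
    for name, version in collect_versions(
        ["torch", "flash-attn", "transformers"]
    ).items():
        print(f"{name}: {version or 'not installed'}")
    print(f"torch built for CUDA: {torch_cuda_build()}")
```

Pasting this output into an issue makes a mismatch such as a cu121 torch build on a CUDA 11.8 system visible at a glance.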

preminstrel avatar Jun 26 '24 04:06 preminstrel

I have added a FAQ in the README :)

preminstrel avatar Jun 26 '24 04:06 preminstrel