ModelTC/lightllm · Issue #13

my A800 80G*8

Open weisihao opened this issue 2 years ago • 7 comments

How can this problem be solved?

self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.14 GiB (GPU 0; 79.35 GiB total capacity; 77.83 GiB already allocated; 711.19 MiB free; 77.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
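The sketch below is not lightllm's code. It estimates how much memory buffers shaped like the `value_buffer` line above request. It assumes there is a key buffer of the same shape for every layer, and that under tensor parallelism `head_num` is the per-GPU head count. `buffer_bytes` and `kv_cache_bytes` are hypothetical helpers. `size` is presumably derived from max_total_token_num, so the total grows linearly with that argument.

```python
# Hypothetical helpers estimating KV-cache buffer memory; not lightllm's API.


def buffer_bytes(size: int, head_num: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """Bytes for one torch.empty((size, head_num, head_dim)) buffer, i.e. one layer.

    dtype_bytes=2 corresponds to fp16.
    """
    return size * head_num * head_dim * dtype_bytes


def kv_cache_bytes(size: int, head_num: int, head_dim: int, layer_num: int,
                   dtype_bytes: int = 2) -> int:
    """Bytes for all layers, ASSUMING one key and one value buffer per layer."""
    return 2 * layer_num * buffer_bytes(size, head_num, head_dim, dtype_bytes)


if __name__ == "__main__":
    # Example: llama-7b (32 heads x 128 dims, 32 layers) split over 2 GPUs,
    # i.e. 16 heads per GPU, with size=6000 tokens.
    per_layer = buffer_bytes(6000, 16, 128)
    total = kv_cache_bytes(6000, 16, 128, 32)
    print(f"one buffer: {per_layer / 1024**3:.3f} GiB, "
          f"all layers per GPU: {total / 1024**3:.3f} GiB")
```

If the total for your `size` exceeds the memory left after loading weights, an allocation like the one in the traceback fails.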

weisihao · Aug 03 '23 06:08

@weisihao What is the model weight size? max_total_token_num needs to be set smaller.

hiworldwzj · Aug 03 '23 07:08

The model is Llama-2-13b-chat-hf. I just tried setting max_total_token_num to 6000 and it worked, thanks!

weisihao · Aug 03 '23 07:08

Is there any relationship between the model size and the max_total_token_num parameter? And how should I set this parameter if I later test with Llama-2-70B?

weisihao · Aug 03 '23 07:08

@weisihao I will commit an update today about how to set this argument. Here it is:

##### --max_total_token_num

Default is 6000.  
This is the total number of tokens the GPUs and model can support (i.e., the KV cache capacity). An example of how to set this argument:  
GPU: 2 × A100 80G (--tp 2)  
model: llama-7b  
dtype: fp16  
llama-7b hidden_size is 4096 and the number of layers is 32.  
The GPU memory left after all weights are loaded:  

80 * 2 - 7 * 2 = 146G (160G of GPU memory minus about 14G of fp16 weights: 7B parameters * 2 bytes)  

GPU memory for one token's KV cache (hidden_size * 2 for K and V * 2 bytes for fp16 * 32 layers):  

4096 * 2 * 2 * 32 / 1024 / 1024 / 1024 =  0.000488281G  

The maximum number of tokens:  

146 / 0.000488281 ≈ 299008  

Of course, this value cannot be set directly, because extra GPU memory is used during model inference. We need to multiply it by a ratio:  

max_total_token_num = 299008 * ratio   

We recommend setting the ratio between 0.8 and 0.9, perhaps slightly higher. If an OOM error occurs, you can reduce the ratio or the "batch_max_tokens" argument.  
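The calculation above can be wrapped in a small helper. This sketch mirrors the thread's arithmetic and is not part of lightllm; `estimate_max_total_token_num` and its parameters are hypothetical. It uses `hidden_size` as the per-layer K/V width, which holds for models without GQA such as llama-7b.

```python
def estimate_max_total_token_num(gpu_mem_gb: float, num_gpus: int, weight_gb: float,
                                 hidden_size: int, num_layers: int,
                                 dtype_bytes: int = 2, ratio: float = 0.85) -> int:
    """Hypothetical estimate following the method described in this thread."""
    # Memory left across all GPUs after loading the weights.
    free_gb = gpu_mem_gb * num_gpus - weight_gb
    # One token stores K and V (x2), each hidden_size wide, in every layer.
    per_token_gb = hidden_size * 2 * dtype_bytes * num_layers / 1024**3
    # Scale down to leave headroom for memory used during inference.
    return int(free_gb / per_token_gb * ratio)


# The llama-7b example: 2 x 80G, ~14G of fp16 weights.
upper_bound = estimate_max_total_token_num(80, 2, 14, 4096, 32, ratio=1.0)
suggested = estimate_max_total_token_num(80, 2, 14, 4096, 32, ratio=0.8)
```

With `ratio=1.0` this reproduces the 299008 figure; the ratio then trades capacity for headroom.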

llama2-70b uses the GQA (grouped-query attention) feature, so the KV cache for one token is different; you can try it.
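With GQA, the per-token KV cache depends on the number of KV heads rather than on hidden_size. This sketch assumes Llama-2-70B's Hugging Face config values (80 layers, 64 query heads, 8 KV heads, head_dim 128) and about 140G of fp16 weights; please verify them against the model's config.json. `kv_bytes_per_token` is a hypothetical helper, not a lightllm function.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """K and V (x2) for every layer, each num_kv_heads * head_dim wide."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


# llama-7b (no GQA): 32 KV heads x 128 = hidden_size 4096, 32 layers.
llama_7b = kv_bytes_per_token(32, 32, 128)
# Llama-2-70B with GQA (assumed config): 8 KV heads, 80 layers.
llama2_70b = kv_bytes_per_token(80, 8, 128)
# The same model if each of its 64 query heads had its own KV head.
llama2_70b_no_gqa = kv_bytes_per_token(80, 64, 128)

# Rough capacity on 8 x 80G minus ~140G of weights, BEFORE applying the ratio.
free_bytes = (80 * 8 - 70 * 2) * 1024**3
capacity_before_ratio = free_bytes // llama2_70b
```

Multiply `capacity_before_ratio` by the 0.8 to 0.9 ratio recommended above to get a starting value for max_total_token_num.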

hiworldwzj · Aug 03 '23 08:08

Thanks for this amazing work. I'd like to know whether you have supported the GQA feature for Llama2-70B in this repo. I am trying to support the GQA feature in FasterTransformer-llama (https://github.com/NVIDIA/FasterTransformer/issues/506), and I would appreciate any reference on GQA.

CN-COTER · Aug 09 '23 07:08

@CN-COTER llama2-70B has been tested; GQA is easy to add in a Triton kernel.
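A plain-Python reference (not lightllm's Triton kernel) showing why GQA is a small change to a decode-attention kernel: query head `h` reads KV head `h // (num_q_heads // num_kv_heads)`, so each group of query heads shares one cached K/V. `gqa_decode_attention` is a hypothetical name, and nested lists stand in for tensors for clarity.

```python
import math
from typing import List

Vector = List[float]


def gqa_decode_attention(q: List[Vector],
                         k_cache: List[List[Vector]],
                         v_cache: List[List[Vector]],
                         num_kv_heads: int) -> List[Vector]:
    """Attention output for one new token.

    q:        [num_q_heads][head_dim]
    k_cache:  [seq_len][num_kv_heads][head_dim]
    v_cache:  [seq_len][num_kv_heads][head_dim]
    Returns:  [num_q_heads][head_dim]
    """
    num_q_heads = len(q)
    assert num_q_heads % num_kv_heads == 0, "query heads must split evenly into KV groups"
    group_size = num_q_heads // num_kv_heads
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    seq_len = len(k_cache)

    out = []
    for h in range(num_q_heads):
        kv_h = h // group_size  # GQA: map this query head to its shared KV head
        scores = [scale * sum(a * b for a, b in zip(q[h], k_cache[t][kv_h]))
                  for t in range(seq_len)]
        # Numerically stable softmax over the cached positions.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(weights[t] * v_cache[t][kv_h][d] for t in range(seq_len)) / total
                    for d in range(head_dim)])
    return out
```

With `num_kv_heads == len(q)` this reduces to ordinary multi-head attention, and with `num_kv_heads == 1` it becomes multi-query attention; only the `kv_h` index changes.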

hiworldwzj · Aug 10 '23 08:08

Thank you for the reply. I referred to the triton_kernel code in lightllm and got it. 😀


CN-COTER · Aug 10 '23 09:08