Llama3.1 inference memory requirements
I have a question about how the inference memory requirements for Llama 3.1 are calculated in this post: https://huggingface.co/blog/llama31#inference-memory-requirements
The table below is an excerpt of the FP16 KV cache sizes from the post.
| Model Size | 1k tokens | 16k tokens | 128k tokens |
|---|---|---|---|
| 8B | 0.125 GB | 1.95 GB | 15.62 GB |
| 70B | 0.313 GB | 4.88 GB | 39.06 GB |
| 405B | 0.984 GB | 15.38 GB | 123.05 GB |
I used the formula in this article to do my own calculations. The formula in the article is as follows:
`2 * num_layers * num_heads * dim_heads * 2`
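This gives the size of the KV cache per token in bytes: the first factor of 2 accounts for the K and V matrices, and the last factor of 2 is the number of bytes per FP16 value.
num_layers, num_heads and dim_heads are the values from the Llama 3 paper. Here num_heads is the number of key/value heads (8), and dim_heads is the model dimension divided by the number of attention heads (4096/32 = 128 for the 8B model).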
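The per-token formula can be wrapped in a small helper so the same arithmetic is reusable for any model and token count. This is a minimal sketch: the function name `kv_cache_gib` and its parameter names are hypothetical, and the 8B values (32 layers, 8 KV heads, head dimension 4096/32 = 128) are the ones used in the calculation that follows. Because the result is divided by 1024**3, "GB" here effectively means GiB.

```python
def kv_cache_gib(num_tokens: int, num_layers: int, num_kv_heads: int,
                 head_dim: int, bytes_per_elem: int = 2) -> float:
    """KV cache size in GiB for `num_tokens` tokens (hypothetical helper).

    Per token: 2 (K and V) * layers * KV heads * head dim * bytes per element.
    bytes_per_elem=2 corresponds to FP16.
    """
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return num_tokens * bytes_per_token / 1024**3


# Llama 3.1 8B: 32 layers, 8 KV heads, head_dim = 4096 / 32 = 128
for tokens in (16_000, 128_000):
    print(tokens, kv_cache_gib(tokens, num_layers=32, num_kv_heads=8, head_dim=128))
# 16000 1.953125
# 128000 15.625
```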
For example, for the 8B model with 16k tokens and 128k tokens, the calculation is as follows and matches the numbers in the table above.
```python
16000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 1.953125
128000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 15.625
```
However, if I calculate the 16k-token and 128k-token cases the same way for the 405B model (126 layers, 8 KV heads, dim_heads = 16384/128), the results do not match the table: they are exactly half of the table values.
```python
16000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 7.6904296875
128000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 61.5234375
```
Am I misunderstanding something, or is there another factor that needs to be taken into account for the 405B model?
Also, the 1k-token numbers differ slightly. Does the table count 1k as 1024 tokens?
```python
1000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 0.1220703125
1024 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 0.125
```
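One way to test the 1024 hypothesis is to check, for the 8B and 70B rows, which reading of each column label (decimal, x1000, or binary, x1024) reproduces the printed value. This sketch assumes 80 layers, 8 KV heads and head dimension 128 for the 70B model (Llama 3 paper values, not stated in this thread) and counts a reading as a match if it lands within 0.5% of the rounded table value; the names `BYTES_PER_TOKEN`, `READINGS` and `matching_readings` are illustrative only.

```python
import math

# FP16 bytes per token: 2 (K and V) * layers * KV heads * head_dim * 2 bytes.
BYTES_PER_TOKEN = {
    "8B": 2 * 32 * 8 * 128 * 2,   # 32 layers, 8 KV heads, head_dim 128
    "70B": 2 * 80 * 8 * 128 * 2,  # 80 layers, 8 KV heads, head_dim 128 (assumed)
}

# FP16 KV cache sizes copied from the table (rounded, as printed in the post).
TABLE = {
    "8B": {"1k": 0.125, "16k": 1.95, "128k": 15.62},
    "70B": {"1k": 0.313, "16k": 4.88, "128k": 39.06},
}

# Two possible readings of each column label: decimal (x1000) or binary (x1024).
READINGS = {"1k": (1_000, 1_024), "16k": (16_000, 16_384), "128k": (128_000, 131_072)}


def kv_gib(num_tokens: int, bytes_per_token: int) -> float:
    """KV cache size in GiB for a given number of tokens."""
    return num_tokens * bytes_per_token / 1024**3


def matching_readings(model: str) -> dict:
    """Per column, return the token counts that reproduce the table value.

    The decimal and binary readings differ by about 2.4%, so a 0.5%
    relative tolerance lets at most one of them match.
    """
    return {
        col: [n for n in counts
              if math.isclose(kv_gib(n, BYTES_PER_TOKEN[model]),
                              TABLE[model][col], rel_tol=0.005)]
        for col, counts in READINGS.items()
    }


for model in TABLE:
    print(model, matching_readings(model))
# 8B {'1k': [1024], '16k': [16000], '128k': [128000]}
# 70B {'1k': [1024], '16k': [16000], '128k': [128000]}
```

Under these assumptions the "1k" column corresponds to 1024 tokens while the "16k" and "128k" columns correspond to 16,000 and 128,000 tokens, so the table does not appear to use a single convention throughout.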
Thank you!
The table in the Llama 3 paper is wrong here: the number of key/value heads for the 405B model is 16, not 8. You can find the explanation in this link.
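To check this explanation, the sketch below recomputes the 405B row with 8 and with 16 KV heads, keeping 126 layers and head dimension 16384/128 = 128 from the calculation in the question. It assumes "1k" means 1024 tokens and "16k"/"128k" mean 16,000/128,000 tokens, as the 8B numbers suggest; `kv_gib_405b` is a hypothetical helper.

```python
def kv_gib_405b(num_tokens: int, num_kv_heads: int) -> float:
    """FP16 KV cache size in GiB for Llama 3.1 405B (hypothetical helper).

    Uses 126 layers and head_dim = 16384 / 128 = 128; the first factor of 2
    is K and V, the last is 2 bytes per FP16 value.
    """
    bytes_per_token = 2 * 126 * num_kv_heads * 128 * 2
    return num_tokens * bytes_per_token / 1024**3


# "1k" read as 1024 tokens, "16k"/"128k" as 16,000/128,000 tokens.
for heads in (8, 16):
    print(f"{heads} KV heads:", [kv_gib_405b(n, heads) for n in (1_024, 16_000, 128_000)])
# 8 KV heads: [0.4921875, 7.6904296875, 61.5234375]
# 16 KV heads: [0.984375, 15.380859375, 123.046875]
```

With 16 KV heads the results round to the table's 0.984, 15.38 and 123.05 GB, exactly twice the 8-head values, which is consistent with the post having used 16 KV heads for the 405B model.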
@ZeusXuan Thank you for the comment! I read the Reddit post. Does this mean the 405B model originally had 16 KV heads, and this has since been changed to 8, matching the paper? I found the following link to the commit that changes it to 8 KV heads: https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-FP8/discussions/15
Hi @satojkovic,
I am following up to ask whether you have figured out the discrepancy between your calculation and the numbers in the post. Your calculation makes sense to me. I think it's not entirely clear whether a 128K context length means 128 * 1000 or 128 * 1024 tokens.
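The size of this ambiguity is easy to quantify. The sketch below uses the same 8B values as the original calculation (the constant and helper names are illustrative) and evaluates "128K" both as 128 * 1000 and as 128 * 1024 tokens.

```python
# FP16 bytes per token for the 8B model: 2 (K and V) * 32 layers * 8 KV heads
# * head_dim 128 * 2 bytes.
BYTES_PER_TOKEN_8B = 2 * 32 * 8 * 128 * 2


def kv_gib(num_tokens: int, bytes_per_token: int = BYTES_PER_TOKEN_8B) -> float:
    """KV cache size in GiB for the given number of tokens."""
    return num_tokens * bytes_per_token / 1024**3


as_decimal = kv_gib(128 * 1000)  # 128,000 tokens
as_binary = kv_gib(128 * 1024)   # 131,072 tokens
print(as_decimal, as_binary, f"{as_binary / as_decimal - 1:.1%}")
# 15.625 16.0 2.4%
```

The two readings always differ by the constant factor 1024/1000 (2.4%), and the 15.62 GB in the post corresponds to the 128,000-token reading.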
Best, Siyuan
Also, https://lmcache.ai/kv_cache_calculator.html