LWM
Memory requirements
It would be worthwhile to provide the measured memory requirements for inference with the text models at 32K, 128K, 256K, 512K, and 1M-token context windows, in both PyTorch and JAX.
If using vLLM for inference (PyTorch model, FP16), I believe we used:
- 1 80GB A100 for 32K
- 2 80GB A100s for 128K
- 4 80GB A100s for 256K
- 8 80GB A100s for 512K
For each of the above, we served one model with tensor parallelism across the given number of devices. With 8 80GB A100s, I think the limit was around 650K to 700K tokens. At startup, vLLM prints the number of KV-cache blocks it allocated, which determines the maximum number of tokens it can hold, so it should be easy to find the limit if you're using GPUs with a different amount of memory.
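As a rough cross-check of the numbers above, here is a back-of-the-envelope sketch of the FP16 KV-cache size and of converting vLLM's reported GPU block count into tokens. It assumes the text model keeps a LLaMA-2-7B-like shape (32 layers, hidden size 4096, no grouped-query attention) and that vLLM uses its default block size of 16 tokens; both are assumptions to check against your config. It covers the cache only, not weights or activation peaks, so measured limits (like the ~650K-700K on 8 A100s) come out lower than cache arithmetic alone suggests.

```python
# Back-of-the-envelope KV-cache sizing. Model shape is an ASSUMPTION
# (LLaMA-2-7B-like: 32 layers, hidden size 4096, full multi-head attention).
GIB = 1024 ** 3

def kv_cache_bytes(num_tokens, num_layers=32, hidden_size=4096, bytes_per_elem=2):
    """Bytes needed to cache K and V for num_tokens tokens (FP16 by default)."""
    # Factor 2: one K and one V vector of size hidden_size per layer per token.
    return 2 * num_layers * hidden_size * bytes_per_elem * num_tokens

def max_tokens_from_blocks(num_gpu_blocks, block_size=16):
    """Tokens the cache can hold, given the GPU block count vLLM reports.

    block_size=16 is ASSUMED to match vLLM's default; check your engine config.
    """
    return num_gpu_blocks * block_size

if __name__ == "__main__":
    for k in (32, 128, 256, 512, 1024):
        tokens = k * 1024  # "32K" taken as 32 * 1024 tokens
        print(f"{k:>5}K tokens: {kv_cache_bytes(tokens) / GIB:7.1f} GiB of KV cache")
```

Under these assumptions each token costs 0.5 MiB of cache, so the 512K row alone needs about 256 GiB, or 32 GiB per GPU across 8 devices before weights and activations are counted.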
For JAX, I'm not sure what the intermediate requirements were, but we needed a v4-256 to run inference on 1M tokens (full FP32 inference). I think further optimizations (e.g. half precision, quantization) could reduce the requirements. Even at full precision, the requirements seemed higher than I expected, so there may be JAX/XLA optimizations left to make (e.g. keeping XLA from padding certain dimensions, which we originally had a lot of trouble with).
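To put the 1M-token FP32 run in perspective, the sketch below computes the KV-cache size at 1M tokens for different element sizes. It assumes the same LLaMA-2-7B-like shape as a back-of-the-envelope estimate; the `int8` row stands for a hypothetical quantized cache, and weights, activations, and any XLA padding are ignored.

```python
# KV-cache size at 1M tokens for different element sizes.
# ASSUMPTION: LLaMA-2-7B-like shape (32 layers, hidden size 4096, no GQA).
TIB = 1024 ** 4
NUM_TOKENS = 2 ** 20  # "1M" context

def kv_cache_bytes(num_tokens, bytes_per_elem, num_layers=32, hidden_size=4096):
    # K and V per layer per token, each hidden_size elements wide.
    return 2 * num_layers * hidden_size * bytes_per_elem * num_tokens

# int8 is a hypothetical quantized cache, included only for comparison.
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "int8": 1}

if __name__ == "__main__":
    for name, nbytes in DTYPE_BYTES.items():
        size = kv_cache_bytes(NUM_TOKENS, nbytes)
        print(f"{name:>8}: {size / TIB:.3f} TiB")
```

Under this assumption the FP32 cache alone is 1 TiB at 1M tokens, and switching to a 16-bit dtype halves it, which is why half precision is the obvious first optimization.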
Any recommendations for running the model on smaller GPUs such as a T4? It runs out of memory (JAX).
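One quick way to see why a T4 runs out of memory is to compare the weight storage alone with the card's memory. This sketch assumes about 6.74B parameters (LLaMA-2-7B scale) and the T4's nominal 16 GB; it ignores the KV cache, activations, and framework overhead, so a dtype that "fits" here can still run out of memory in practice.

```python
# Weight memory vs. device memory for a ~7B-parameter model.
# ASSUMPTIONS: 6.74e9 parameters (LLaMA-2-7B scale) and a nominal 16 GB T4.
NUM_PARAMS = 6.74e9
T4_BYTES = 16e9

BYTES_PER_PARAM = {"float32": 4.0, "float16": 2.0, "int8": 1.0, "int4": 0.5}

def weight_bytes(dtype):
    """Bytes for the weights alone (no KV cache, activations, or overhead)."""
    return NUM_PARAMS * BYTES_PER_PARAM[dtype]

def weights_fit(dtype, device_bytes=T4_BYTES):
    return weight_bytes(dtype) < device_bytes

if __name__ == "__main__":
    for dtype in BYTES_PER_PARAM:
        gb = weight_bytes(dtype) / 1e9
        print(f"{dtype:>7}: {gb:5.1f} GB, weights fit on T4: {weights_fit(dtype)}")
```

If the JAX model runs in full FP32, as in the 1M-token run above, the weights alone (about 27 GB) exceed a T4, so half precision or quantization is a prerequisite there rather than an optimization.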
@wilson1yan Can you share the shell/bash script you used to set up the vLLM inference server for the PyTorch model in FP16?
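No script was posted in reply here. The following is a hypothetical sketch of how the reported setup might map onto vLLM's Python API: it picks the tensor-parallel size from the GPU counts reported above and passes it, with `dtype="float16"` and `max_model_len`, to `vllm.LLM`. The Hugging Face model ID is an assumption (verify the actual repo name), "32K" is taken as 32 * 1024 tokens, and the mapping only reproduces the counts reported for 80GB A100s.

```python
# Hypothetical helper mapping a target context length to vLLM settings,
# based on the A100-80GB counts reported in this thread.
# ASSUMPTIONS: the Hugging Face model ID and that these counts transfer to your setup.
MODEL_ID = "LargeWorldModel/LWM-Text-Chat-1M"  # assumed repo name; verify before use

# (max context in tokens, number of 80GB A100s) as reported above
REPORTED_SETUPS = [
    (32 * 1024, 1),
    (128 * 1024, 2),
    (256 * 1024, 4),
    (512 * 1024, 8),
]

def tensor_parallel_size_for(context_len):
    """Smallest reported GPU count whose context limit covers context_len."""
    for max_len, num_gpus in REPORTED_SETUPS:
        if context_len <= max_len:
            return num_gpus
    raise ValueError(f"{context_len} tokens exceeds the largest reported setup")

def vllm_kwargs(context_len):
    return {
        "model": MODEL_ID,
        "tensor_parallel_size": tensor_parallel_size_for(context_len),
        "dtype": "float16",
        "max_model_len": context_len,
    }

def build_llm(context_len):
    # Imported lazily so the helpers above work without vLLM installed.
    from vllm import LLM
    return LLM(**vllm_kwargs(context_len))
```

For example, `build_llm(128 * 1024)` would request 2-way tensor parallelism; the resulting object's `generate(prompts, SamplingParams(...))` runs offline inference. vLLM's OpenAI-compatible server exposes the same settings as `--tensor-parallel-size`, `--dtype`, and `--max-model-len`.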
I’m thinking an attention optimization like top-k would be appropriate here. Could a user calculate their own position_ids and pass only a subset of the tokens, perhaps over multiple passes, dropping tokens that don’t affect the results?
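One way to read this suggestion: score each context token (how the scores are obtained, e.g. from attention weights in an earlier pass, is left open and is hypothetical here), keep the top-k, and pass the kept tokens together with their original indices as `position_ids`, so rotary position embeddings still see the true distances. Hugging Face's LLaMA models accept a `position_ids` argument in `forward`. A minimal sketch of the selection step, with a hypothetical function name:

```python
# Hypothetical top-k context pruning that preserves original positions.
# The importance scores are an ASSUMPTION: any per-token relevance signal works.
import heapq

def select_topk_tokens(token_ids, scores, k):
    """Keep the k highest-scoring tokens, in original order, with their positions.

    Returns (kept_token_ids, position_ids); position_ids are the tokens'
    indices in the full sequence, not 0..k-1, so relative distances survive.
    """
    if len(token_ids) != len(scores):
        raise ValueError("token_ids and scores must have the same length")
    k = min(k, len(token_ids))
    top = heapq.nlargest(k, range(len(token_ids)), key=lambda i: scores[i])
    positions = sorted(top)  # restore left-to-right order for causal attention
    return [token_ids[i] for i in positions], positions
```

Multiple passes would repeat this with a shrinking k; whether the pruned context preserves answer quality is an empirical question this sketch does not address.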
Aren't those requirements a bit high for a 7B model with a 32K context? Mistral 7B v0.2 (32K context) runs fine on consumer-grade GPUs, especially in quantized form (e.g. high-quality Q6_K GGUFs).
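Part of the gap is likely architectural rather than about quantization. Mistral 7B uses grouped-query attention with 8 key/value heads, while a LLaMA-2-7B-like shape (assumed here for the LWM text model) caches K and V for all 32 heads. Assuming both have 32 layers and head dimension 128, the FP16 KV cache at 32K tokens differs by 4x. Weights and activations are not included in this sketch.

```python
# FP16 KV-cache size: full multi-head attention vs. grouped-query attention.
# ASSUMED shapes: 32 layers, head_dim 128; 32 KV heads (LLaMA-2-7B-like)
# vs. 8 KV heads (Mistral 7B). Weights and activations are not included.
GIB = 1024 ** 3

def kv_cache_bytes(num_tokens, num_kv_heads, num_layers=32, head_dim=128, bytes_per_elem=2):
    # K and V tensors per layer, each num_kv_heads * head_dim wide per token.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * num_tokens

if __name__ == "__main__":
    tokens = 32 * 1024
    print(f"32 KV heads: {kv_cache_bytes(tokens, 32) / GIB:.1f} GiB")
    print(f" 8 KV heads: {kv_cache_bytes(tokens, 8) / GIB:.1f} GiB")
```

Quantizing the weights (as with Q6_K) shrinks only the weight term; the cache term depends on the number of KV heads and the cache dtype.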