
Memory requirements

Open loretoparisi opened this issue 1 year ago • 5 comments

It would be worth providing the measured memory requirements for running inference with the Text models at 32K, 128K, 256K, 512K, and 1M token context windows, in both PyTorch and JAX.

loretoparisi avatar Feb 14 '24 21:02 loretoparisi

If using vLLM for inference (PyTorch model, FP16), I believe we used:

  • 1 80GB A100 for 32K
  • 2 80GB A100s for 128K
  • 4 80GB A100s for 256K
  • 8 80GB A100s for 512K

Each of the above serves a single model with tensor parallelism over the given number of devices. With 8 80GB A100s, I think the limit was around 650K-700K tokens. vLLM prints out the maximum number of tokens it can support (derived from the number of cache blocks it allocates), so it should be easy to tell what fits if you're using GPUs with different amounts of memory.
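
A minimal sketch of what this setup can look like with vLLM's Python API; this is not the exact script used here, and the checkpoint id, `tensor_parallel_size`, and `max_model_len` below are placeholders to adjust for your hardware and target context window:

```python
# Rough sketch of FP16 inference with vLLM and tensor parallelism.
# The checkpoint id, device count, and context length are illustrative
# placeholders, not the authors' exact settings.
from vllm import LLM, SamplingParams

llm = LLM(
    model="LargeWorldModel/LWM-Text-Chat-1M",  # or a local checkpoint path
    dtype="float16",
    tensor_parallel_size=8,    # e.g. 8x 80GB A100 for ~512K-token contexts
    max_model_len=524288,      # context window to reserve KV-cache blocks for
)

sampling = SamplingParams(temperature=0.0, max_tokens=256)
outputs = llm.generate(["<very long document> ... Summarize the text above."], sampling)
print(outputs[0].outputs[0].text)
```

On startup, vLLM logs how many KV-cache blocks it allocated, which is the figure mentioned above for estimating the maximum supported token count on your GPUs.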

For Jax, I'm not too sure what the intermediate requirements were, but we needed a v4-256 to do inference on 1M tokens (full FP32 inference). I think more optimizations could be made (e.g. half precision, quantization, etc.) to shrink the requirements. Even at full precision, the requirements seemed higher than I expected, and there might be some Jax / XLA optimizations to be made (e.g. keeping it from padding certain dimensions, which we originally had a lot of trouble with).
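
As a generic illustration of the half-precision point (not LWM's own loading code), casting a JAX parameter pytree from float32 to bfloat16 roughly halves parameter memory; the dummy params below stand in for real model weights:

```python
# Generic JAX sketch: cast a parameter pytree from float32 to bfloat16
# to roughly halve parameter memory. The dummy params stand in for real
# model weights; LWM's own checkpoint loading is not shown here.
import jax
import jax.numpy as jnp

params = {
    "attn": {"wq": jnp.ones((4096, 4096), dtype=jnp.float32)},
    "mlp": {"w1": jnp.ones((4096, 11008), dtype=jnp.float32)},
}

params_bf16 = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x,
    params,
)

print(jax.tree_util.tree_map(lambda x: x.dtype, params_bf16))
```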

wilson1yan avatar Feb 14 '24 22:02 wilson1yan

Any recommendations for running the model on smaller GPUs (T4)? It runs out of memory (JAX).

blazorin avatar Feb 21 '24 05:02 blazorin

@wilson1yan Can you share the shell/bash script for setting up the inference server via vLLM for the PyTorch model in FP16?

Playerrrrr avatar Mar 10 '24 08:03 Playerrrrr

I'm thinking an attention kernel optimization like top-k attention would be appropriate here. Could a user calculate their own position_ids and pass only a subset of the tokens, perhaps making multiple passes and dropping tokens that don't impact the results?
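
For reference, a toy, self-contained sketch of the top-k idea (each query attends only to its k highest-scoring keys); this is purely illustrative, not LWM's kernel, and the shapes and top_k value are arbitrary:

```python
# Toy illustration of top-k attention (not LWM's actual kernel): for each
# query, attend only to the k highest-scoring keys instead of the full
# context, bounding the softmax and value gather to k entries per query.
import torch
import torch.nn.functional as F

def topk_attention(q, k, v, top_k=64):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5   # (b, h, q_len, k_len)
    topk_scores, topk_idx = scores.topk(top_k, dim=-1)      # keep the k best keys per query
    weights = F.softmax(topk_scores, dim=-1)                 # softmax over the kept keys only
    # Gather the value vectors for the selected keys and combine them.
    idx = topk_idx.unsqueeze(-1).expand(*topk_idx.shape, v.shape[-1])
    v_sel = v.unsqueeze(2).expand(-1, -1, q.shape[2], -1, -1).gather(3, idx)
    return (weights.unsqueeze(-1) * v_sel).sum(dim=3)

q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 4096, 64)
v = torch.randn(1, 8, 4096, 64)
out = topk_attention(q, k, v)   # (1, 8, 128, 64)
```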

xloem avatar Mar 10 '24 14:03 xloem

Aren't those requirements a bit high for a 7B model with a 32K context? Mistral 7B 0.2 (32K context) works absolutely fine on consumer-grade GPUs, especially when using quantized versions like high-quality Q6_K GGUFs.
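
For a rough sense of where the memory goes, here is a back-of-the-envelope KV-cache estimate, assuming a LLaMA-2-7B-style config (32 layers, 32 KV heads, head dim 128, no GQA) and unquantized FP16 caches; at 32K tokens the cache alone is about 16 GiB, on top of roughly 13-14 GiB of FP16 weights, which is why quantized weights and caches change the picture so much on consumer GPUs:

```python
# Back-of-the-envelope KV-cache sizing, assuming a LLaMA-2-7B-like config
# (32 layers, 32 KV heads, head_dim 128, no grouped-query attention) and
# FP16 (2-byte) cache entries. vLLM overhead and activations come on top.
def kv_cache_gib(seq_len, layers=32, kv_heads=32, head_dim=128, bytes_per_elem=2):
    # Factor of 2 for keys and values.
    return 2 * layers * kv_heads * head_dim * bytes_per_elem * seq_len / 1024**3

for ctx in (32_768, 131_072, 262_144, 524_288, 1_048_576):
    print(f"{ctx:>9} tokens: ~{kv_cache_gib(ctx):6.1f} GiB KV cache")
```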

MoonRide303 avatar Apr 08 '24 08:04 MoonRide303