Recent vLLMs ask for too much memory: ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.
Since vLLM 0.2.5, we can't even run llama-2 70B 4-bit AWQ on 4*A10G anymore and have to use an old vLLM. Similar problems occur even when trying to run two 7B models on an 80GB A100.
For small models, like a 7B with 4k tokens, vLLM fails with the "cache blocks" error even though a lot more memory is left.
E.g., building a docker image with CUDA 11.8 and vLLM 0.2.5 or 0.2.6 and running like:
port=5001
tokens=8192
docker run -d \
--runtime=nvidia \
--gpus '"device=1"' \
--shm-size=10.24gb \
-p $port:$port \
--entrypoint /h2ogpt_conda/vllm_env/bin/python3.10 \
-e NCCL_IGNORE_DISABLED_P2P=1 \
-v /etc/passwd:/etc/passwd:ro \
-v /etc/group:/etc/group:ro \
-u `id -u`:`id -g` \
-v "${HOME}"/.cache:/workspace/.cache \
--network host \
gcr.io/vorvan/h2oai/h2ogpt-runtime:0.1.0 -m vllm.entrypoints.openai.api_server \
--port=$port \
--host=0.0.0.0 \
--model=defog/sqlcoder2 \
--seed 1234 \
--trust-remote-code \
--max-num-batched-tokens $tokens \
--max-model-len=$tokens \
--gpu-memory-utilization 0.4 \
--download-dir=/workspace/.cache/huggingface/hub &>> logs.vllm_server.sqlcoder2.txt
port=5002
tokens=4096
docker run -d \
--runtime=nvidia \
--gpus '"device=1"' \
--shm-size=10.24gb \
-p $port:$port \
--entrypoint /h2ogpt_conda/vllm_env/bin/python3.10 \
-e NCCL_IGNORE_DISABLED_P2P=1 \
-v /etc/passwd:/etc/passwd:ro \
-v /etc/group:/etc/group:ro \
-u `id -u`:`id -g` \
-v "${HOME}"/.cache:/workspace/.cache \
--network host \
gcr.io/vorvan/h2oai/h2ogpt-runtime:0.1.0 -m vllm.entrypoints.openai.api_server \
--port=$port \
--host=0.0.0.0 \
--model=NumbersStation/nsql-llama-2-7B \
--seed 1234 \
--trust-remote-code \
--max-num-batched-tokens $tokens \
--gpu-memory-utilization 0.6 \
--max-model-len=$tokens \
--download-dir=/workspace/.cache/huggingface/hub &>> logs.vllm_server.nsql7b.txt
works. However, if the 2nd model is instead given 0.4, one gets:
Traceback (most recent call last):
File "/h2ogpt_conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/h2ogpt_conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 729, in <module>
engine = AsyncLLMEngine.from_engine_args(engine_args)
File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 496, in from_engine_args
engine = cls(parallel_config.worker_use_ray,
File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 269, in __init__
self.engine = self._init_engine(*args, **kwargs)
File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 314, in _init_engine
return engine_class(*args, **kwargs)
File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 113, in __init__
self._init_cache()
File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 227, in _init_cache
raise ValueError("No available memory for the cache blocks. "
ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.
However, with 0.6 util from before, here is what GPU looks like:
Sun Dec 24 02:45:53 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100 80GB PCIe Off | 00000000:00:06.0 Off | 0 |
| N/A 43C P0 72W / 300W | 70917MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A100 80GB PCIe Off | 00000000:00:07.0 Off | 0 |
| N/A 45C P0 66W / 300W | 49136MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 6232 C /h2ogpt_conda/vllm_env/bin/python3.10 70892MiB |
| 1 N/A N/A 6966 C /h2ogpt_conda/vllm_env/bin/python3.10 32430MiB |
| 1 N/A N/A 7685 C /h2ogpt_conda/vllm_env/bin/python3.10 16670MiB |
Ignore GPU 0.
So the 0.6-util model only takes ~17GB; why would 0.4 util out of 80GB be a problem?
vLLM 0.2.6 added CUDA graph support, which is enabled by default (probably not a good decision). CUDA graphs introduce a bit more memory overhead. Try adding the --enforce-eager flag to see if it helps; this flag disables CUDA graph execution.
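For example, here is a minimal sketch using the offline Python API (the same flag can simply be appended to the server commands above; the model name and limits are placeholders copied from those commands):

from vllm import LLM

# Minimal sketch: disable CUDA graph capture to avoid its extra memory overhead.
llm = LLM(
    model="defog/sqlcoder2",
    max_model_len=8192,
    gpu_memory_utilization=0.4,
    enforce_eager=True,  # equivalent to passing --enforce-eager to the API server
)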
Thanks for responding. However, we had problems starting with 0.2.5.
If you need a specific snapshot or something for 4*A10G using 70B AWQ on 0.2.4 vs. 0.2.5 let me know. Or what kind of repro do you need?
Oh I see. Sorry for not reading your issue carefully. vLLM 0.2.5 changed the way memory is profiled with #2031. While the new profiling method is more accurate, it doesn't seem to take into account multiple instances running together or GPU memory used by other processes. https://github.com/vllm-project/vllm/blob/1db83e31a2468cae37f326a642c0a4c4edbb5e4f/vllm/worker/worker.py#L100
Here, vLLM basically assumes that any occupied GPU memory is attributable to the currently running instance, and calculates the number of available blocks based on that. This may explain the problem when running two 7B models on one GPU. Not quite sure about the 4xA10G case, though. Is the GPU empty, or shared with other processes in that case?
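To illustrate the failure mode being described, here is a rough, hedged sketch of that accounting (not the actual vLLM code; the block size is a made-up placeholder):

import torch

# Device-wide accounting: anything another process has allocated on the same GPU
# gets counted as if this engine were using it.
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = total_gpu_memory - free_gpu_memory  # includes other processes' memory

gpu_memory_utilization = 0.4
cache_block_size = 2 * 1024**2  # hypothetical bytes per KV-cache block
num_gpu_blocks = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)
# If another instance already uses more than 40% of the GPU, num_gpu_blocks drops
# to <= 0 and the "No available memory for the cache blocks" error is raised.
print(max(num_gpu_blocks, 0))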
Just tried to write a fix. You can try it out: #2249
Our biggest issue is with clean GPUs: four A10Gs running 70B AWQ, with nothing else on the GPUs.
You could change vllm/worker/worker.py like this (see the #@ notes):
def load_model(self):
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()  #@ add
    self.gpu_mem_pre_occupied = total_gpu_memory - free_gpu_memory  #@ add
    self.model_runner.load_model()

@torch.inference_mode()
def profile_num_available_blocks(
    self,
    block_size: int,
    gpu_memory_utilization: float,
    cpu_swap_space: int,
) -> Tuple[int, int]:
    # Profile the memory usage of the model and get the maximum number of
    # cache blocks that can be allocated with the remaining free memory.
    torch.cuda.empty_cache()

    # Execute a forward pass with dummy inputs to profile the memory usage
    # of the model.
    self.model_runner.profile_run()

    # Calculate the number of blocks that can be allocated with the
    # profiled peak memory.
    torch.cuda.synchronize()
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    peak_memory = total_gpu_memory - free_gpu_memory

    cache_block_size = CacheEngine.get_cache_block_size(
        block_size, self.model_config, self.parallel_config)
    #@ add self.gpu_mem_pre_occupied to fix the estimate
    num_gpu_blocks = int(
        (total_gpu_memory * gpu_memory_utilization - peak_memory +
         self.gpu_mem_pre_occupied) // cache_block_size)
    ...
Then total_gpu_memory * gpu_memory_utilization will be the real amount of memory you want to allocate to this model, relative to the GPU's total memory, unaffected by any other models already loaded.
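Here is a hedged numeric illustration of why that correction matters, with made-up numbers that only roughly mirror the nvidia-smi output earlier in the thread:

# 80GB A100 shared by two models; the 1st model already occupies ~33 GB.
total_gpu_memory = 80 * 1024**3
gpu_memory_utilization = 0.4          # requested budget for the 2nd model
gpu_mem_pre_occupied = 33 * 1024**3   # memory already used by the 1st model
peak_memory = 47 * 1024**3            # device-wide "used" after profiling,
                                      # which includes the 1st model's 33 GB
cache_block_size = 2 * 1024**2        # hypothetical bytes per KV-cache block

# Without the correction: 0.4 * 80 GB - 47 GB < 0  ->  "No available memory".
blocks_without_fix = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)

# With the correction, the 1st model's memory is subtracted back out: ~18 GB of KV cache.
blocks_with_fix = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory
     + gpu_mem_pre_occupied) // cache_block_size)

print(blocks_without_fix, blocks_with_fix)  # negative vs. a usable block count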
We are having the exact same issue on our end: cache usage grows and consumes more than the allocated gpu_memory_utilization, even when using --enforce-eager.
We had the same problem before with 0.2.1
having the same issue on cuda 11.8 and vllm 0.2.5 and 0.2.6
same here
Same issue -- starting with vllm 0.2.5
same issue when use vllm 0.2.6
same here
same here
same here
same here
@Snowdar @hanzhi713 et al. I want to be clear again: the primary issue is that even a single model sharded across GPUs no longer works. Forget about multiple models per GPU for now.
That is, on AWS 4*A10G, vLLM 0.2.4 and lower work perfectly fine and leave plenty of room without any failure.
However, on 0.2.5+, no matter what gpu_memory_utilization etc. is set to, the llama 70B AWQ model will never fit on the 4 A10Gs, while before it was perfectly fine (even under heavy use for long periods).
I'm working on v0.2.5 now and found this issue for the same reason. My case is deploying a 70B BF16 model on 8xA100-40GB GPUs. I inserted logging into worker.py to get a sense of where this error comes from:
torch.cuda.empty_cache()
# Here the free memory is ~22 GB per GPU. This is expected given
# 40 - (70 * 2)/8 = 22.5 (70B params at 2 bytes each across 8 GPUs).
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# Here the free memory is only 0.26 GB per GPU. It looks like profile_run()
# consumes all the memory, though I don't yet know why.
I dived in a bit and here are some findings:
- When serving large models (e.g. 70B), the model forward pass itself introduces memory fragmentation. I logged the free memory after each decoder layer and found that it shrinks after every layer. In the case of the 70B model, after 80 layers the free memory is only ~2 GB out of 40 GB per GPU.
- The profile run samples with top_k=vocab-1. This results in fairly high memory usage when the vocab size is large.
- The GPU cache block estimation does not consider fragmentation. Combining the above two, the free memory is less than 1 GB, which results in a very small batch or even no GPU blocks available for the KV cache.
My temporary solution is as follows:
- Manually add torch.cuda.empty_cache() in worker.py before the line free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info(). This removes the impact of fragmentation (see the sketch after this list).
- The above change makes OOM possible when actually serving the model, because empty_cache() also removes the impact of the intermediate tensors of the forward pass. As a result, tuning --gpu-memory-utilization becomes more important, as we have to use it to cover those intermediate tensors. Here are my testing results with different util values:
  - 0.8: 2828 blocks = 45248 tokens
  - 0.9: 3644 blocks = 58304 tokens
  - 1.0: OOM
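A minimal sketch of that patch and of the block-to-token arithmetic, assuming the default block size of 16 tokens (the exact location in vllm/worker/worker.py varies by version):

import torch

# In profile_num_available_blocks, just before the free-memory measurement:
torch.cuda.synchronize()
torch.cuda.empty_cache()  # added: release cached/fragmented allocator blocks first
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()

# With the default block size of 16 tokens, the measurements above work out to:
#   gpu_memory_utilization=0.8 -> 2828 blocks * 16 = 45248 tokens
#   gpu_memory_utilization=0.9 -> 3644 blocks * 16 = 58304 tokens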
Yet another version of this problem is that 01-ai/Yi-34B-Chat used to work perfectly fine on 4*H100 80GB when run like:
python -m vllm.entrypoints.openai.api_server --port=5000 --host=0.0.0.0 --model 01-ai/Yi-34B-Chat --seed 1234 --tensor-parallel-size=4 --trust-remote-code
But it no longer works since 0.2.5+, including 0.2.7. Instead I get:
INFO 01-16 14:40:02 api_server.py:750] args: Namespace(host='0.0.0.0', port=5000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], served_model_name=None, ch>
2024-01-16 14:40:04,623 INFO worker.py:1673 -- Started a local Ray instance.
INFO 01-16 14:40:06 llm_engine.py:70] Initializing an LLM engine with config: model='01-ai/Yi-34B-Chat', tokenizer='01-ai/Yi-34B-Chat', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust>
INFO 01-16 14:41:00 llm_engine.py:294] # GPU blocks: 0, # CPU blocks: 4369
Traceback (most recent call last):
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 760, in <module>
engine = AsyncLLMEngine.from_engine_args(engine_args)
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 544, in from_engine_args
engine = cls(parallel_config.worker_use_ray,
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 274, in __init__
self.engine = self._init_engine(*args, **kwargs)
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 319, in _init_engine
return engine_class(*args, **kwargs)
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 114, in __init__
self._init_cache()
File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 298, in _init_cache
raise ValueError("No available memory for the cache blocks. "
ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.
When can we expect a fix? It seems a pretty serious bug.
BTW, curiously, I ran the same exact command a second time (both times with nothing on the GPUs), and the second time it didn't hit the error. So maybe there is a race in vLLM's memory size detection.
I am trying to run this command as given in docs
python3 -u -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --model mistralai/Mistral-7B-Instruct-v0.2
It gives me an error
File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 279, in _init_cache
raise ValueError("No available memory for the cache blocks. "
ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.
What should I do? I am running a RunPod instance with 1x RTX 4000 Ada.
I have upgraded to 1x A100 and am now passing the --gpu_memory_utilization 0.8 param, but I still get the same error.
The issue was resolved by adding --tensor-parallel-size 1.
The reason it helped is that I am running a RunPod instance, which, as I understand it, only gives me access to the requested GPUs attached to the physical machine.
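For reference, a minimal sketch of that working setup via the offline API, assuming the values mentioned above:

from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    tensor_parallel_size=1,      # match the single GPU actually visible on the instance
    gpu_memory_utilization=0.8,  # value the poster was already passing
)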
I also encountered this problem (i.e., OOM, or too few KV cache blocks) on a 70B LLM with v0.2.7 and dived in a bit. Here are my findings.
My dev environment: an 8x A800 GPU machine with CUDA 11.3.
Working Solution: use peak_memory = torch.cuda.max_memory_allocated() in worker.py (basically revert https://github.com/vllm-project/vllm/pull/2031).
Another Working Solution: update to torch==2.1.2.
Analysis: there is evidence of more memory fragmentation when tp > 1, see here and here. It seems that record_stream (called for NCCL communication) prevents the cached memory blocks of activations from being reused, so the memory consumption of one forward pass grows substantially. Setting TORCH_NCCL_AVOID_RECORD_STREAMS=1 fixes this by stashing references to the related memory blocks and doing proper synchronization instead of calling record_stream. This environment variable is already set in vLLM v0.2.7, but the PyTorch version on my dev machine is 2.0.1, which does not yet support TORCH_NCCL_AVOID_RECORD_STREAMS. Updating to torch==2.1.2 solves the problem. Both options are sketched below.
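A hedged sketch of the two options above (illustrative only, not the actual worker.py diff):

import torch

# Option 1 (revert #2031): measure peak memory through the torch allocator, so
# fragmentation and other processes' allocations are not counted against this engine.
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated()

# Option 2: the TORCH_NCCL_AVOID_RECORD_STREAMS workaround only takes effect on
# PyTorch builds that implement it (2.1+ per the comment above), so check the version.
print(torch.__version__)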
@ZiyueHuang I have pytorch 2.1.2 and vllm 0.2.7 and this wasn't solved by that.
@pseudotensor How about trying reverting https://github.com/vllm-project/vllm/pull/2031?
@ZiyueHuang Yes, I'm trying that now.
This issue was closed automatically by GitHub; that was not correct.
Reverting avoided the error in the title, but it then hit a GPU OOM, unlike 0.2.4 with the same long-context query. FYI @sh1ng
I've upgraded to the latest ray/ray-llm and am now having this issue. Is there a known hotfix? gpu_memory_utilization doesn't help at all.
Trying to run the TheBloke 70B model with this config:
engine_config:
  model_id: TheBloke/Llama-2-70B-chat-AWQ
  hf_model_id: TheBloke/Llama-2-70B-chat-AWQ
  type: VLLMEngine
  engine_kwargs:
    quantization: awq
    max_num_batched_tokens: 32768
    max_num_seqs: 256
    gpu_memory_utilization: 0.90
  max_total_tokens: 4096
FYI @pseudotensor
I've tested the memory footprint of 0.2.4 and 0.2.7, and this is my finding:
- I'm sure that https://github.com/vllm-project/vllm/pull/2031 is correct and should be there. Before #2031, non-torch-related allocations were completely ignored:
|<-------------------------------------total GPU memory---------------------------------------->|
|<---Allocated by torch allocator--->|<--Allocated by NCCL, cuBLAS, etc-->|<--free GPU memory-->|
- #2031 just computes it correctly. We still need to fix the peak memory computation in the case of multiple memory-consuming processes.
- Running 0.2.4 and 0.2.7 consumes exactly the same amount of memory for a model (measured both the old and the new way).
- Changing the nccl version doesn't change memory consumption significantly (~10 MB).
- When using --enforce-eager the memory consumption is a little bit lower.
- Using PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True helps and also makes execution with and without --enforce-eager identical (see the sketch below). I'm not sure how stable it is, as it's marked as experimental.
- I believe that by carefully tuning gpu_memory_utilization we can get the original behavior, as I don't see an increase in memory consumption.
- It's better to fully dedicate a subset of GPUs to a single vLLM model and not share a GPU across multiple models, as the NCCL, cuBLAS, and torch overheads will multiply.
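A hedged sketch of combining the allocator setting with a dedicated model instance, using the offline API; the tensor-parallel size here is an assumption, and the model is the one from the config above:

import os

# Expandable segments reduce fragmentation from variable-sized activation
# allocations. Set it before torch initializes its CUDA allocator.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

from vllm import LLM

llm = LLM(
    model="TheBloke/Llama-2-70B-chat-AWQ",
    quantization="awq",
    tensor_parallel_size=4,       # hypothetical: GPUs dedicated to this one model
    gpu_memory_utilization=0.90,  # value from the config above
)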