Start multiple models
We are trying kvcached, but we sometimes fail to start two Qwen3-0.6B instances with vllm/vllm-openai:v0.8.5.post1 (we can't use v0.11.1 since our CUDA version is < 12.8) due to memory allocation failures. Could you give suggestions for starting multiple models, e.g., how to set gpu-memory-utilization or other parameters? We set gpu-memory-utilization to 0.4, but that does not always seem to work. Thanks!
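For reference, this is roughly how we start the two instances with a static memory split (a sketch; the ports are illustrative and 0.4 is the value we tried):
# two Qwen3-0.6B servers, each statically capped to 40% of GPU memory
vllm serve Qwen/Qwen3-0.6B --port 8887 --gpu-memory-utilization 0.4 &
vllm serve Qwen/Qwen3-0.6B --port 8888 --gpu-memory-utilization 0.4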
Hi @inforly, thank you for trying kvcached!
Would you mind sharing the error logs so we can locate the problem?
Also, when using kvcached, you don't need to set gpu-memory-utilization; kvcached dynamically allocates memory for the KV cache.
We will also try vllm/vllm-openai:v0.8.5.post1 with two Qwen3-0.6B instances and let you know.
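For example, a minimal launch with kvcached enabled looks roughly like this (a sketch; the port is illustrative, and note that no --gpu-memory-utilization is passed):
# enable kvcached; KV cache memory is then allocated on demand instead of statically reserved
export ENABLE_KVCACHED=true
vllm serve Qwen/Qwen3-0.6B --port 8887 --no-enable-prefix-caching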
I just tested kvcached in the vllm/vllm-openai:v0.8.5.post1 Docker image. It seems the current patch doesn't support this version well; we will add support ASAP. CC @ivanium
root@instance-20251016-031138:/vllm-workspace/kvcached# vllm serve meta-llama/Llama-3.2-1B --disable-log-requests --no-enable-prefix-caching --port=12346
INFO 10-26 08:17:54 [__init__.py:239] Automatically detected platform cuda.
[kvcached][INFO][2025-10-26 08:17:56][patch_base.py:98] Applying 6 patches for vllm
[kvcached][INFO][2025-10-26 08:17:58][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][INFO][2025-10-26 08:17:58][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][INFO][2025-10-26 08:17:58][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][WARNING][2025-10-26 08:17:58][patches.py:408] Failed to apply patch_allocation_methods
[kvcached][WARNING][2025-10-26 08:17:58][patches.py:408] Failed to apply patch_reshape_methods
[kvcached][WARNING][2025-10-26 08:17:58][patch_base.py:119] Failed to apply gpu_model_runner
[kvcached][INFO][2025-10-26 08:17:58][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][ERROR][2025-10-26 08:17:58][patch_base.py:163] Could not import target module vllm.v1.core.kv_cache_coordinator for kv_cache_coordinator: No module named 'vllm.v1.core.kv_cache_coordinator'
[kvcached][WARNING][2025-10-26 08:17:58][patch_base.py:119] Failed to apply kv_cache_coordinator
[kvcached][INFO][2025-10-26 08:17:58][patch_base.py:178] Successfully patched vllm: elastic_block_pool, engine_core, gpu_worker
[kvcached][WARNING][2025-10-26 08:17:58][patch_base.py:180] Failed to patch vllm: gpu_model_runner, kv_cache_coordinator
INFO 10-26 08:17:59 [api_server.py:1043] vLLM API server version 0.8.5.post1
INFO 10-26 08:17:59 [api_server.py:1044] args: Namespace(subparser='serve', model_tag='meta-llama/Llama-3.2-1B', config='', host=None, port=12346, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='meta-llama/Llama-3.2-1B', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, load_format='auto', download_dir=None, model_loader_extra_config={}, use_tqdm_on_load=True, config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', max_model_len=None, guided_decoding_backend='auto', reasoning_parser=None, logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, disable_custom_all_reduce=False, block_size=None, gpu_memory_utilization=0.9, swap_space=4, kv_cache_dtype='auto', num_gpu_blocks_override=None, enable_prefix_caching=False, prefix_caching_hash_algo='builtin', cpu_offload_gb=0, calculate_kv_scales=False, disable_sliding_window=False, use_v2_block_manager=True, seed=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_token=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config={}, limit_mm_per_prompt={}, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=None, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=None, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', speculative_config=None, ignore_patterns=[], served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, max_num_batched_tokens=None, max_num_seqs=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, num_lookahead_slots=0, scheduler_delay_factor=0.0, preemption_mode=None, num_scheduler_steps=1, multi_step_stream_outputs=True, scheduling_policy='fcfs', enable_chunked_prefill=None, disable_chunked_mm_input=False, scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, additional_config=None, enable_reasoning=False, disable_cascade_attn=False, disable_log_requests=True, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False, 
dispatch_function=<function ServeSubcommand.cmd at 0x7b12087da200>)
INFO 10-26 08:18:08 [config.py:717] This model supports multiple tasks: {'embed', 'classify', 'generate', 'score', 'reward'}. Defaulting to 'generate'.
INFO 10-26 08:18:08 [config.py:2003] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 10-26 08:18:13 [__init__.py:239] Automatically detected platform cuda.
[kvcached][INFO][2025-10-26 08:18:16][patch_base.py:98] Applying 6 patches for vllm
[kvcached][INFO][2025-10-26 08:18:16][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][INFO][2025-10-26 08:18:16][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][INFO][2025-10-26 08:18:16][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][WARNING][2025-10-26 08:18:16][patches.py:408] Failed to apply patch_allocation_methods
[kvcached][WARNING][2025-10-26 08:18:16][patches.py:408] Failed to apply patch_reshape_methods
[kvcached][WARNING][2025-10-26 08:18:16][patch_base.py:119] Failed to apply gpu_model_runner
[kvcached][INFO][2025-10-26 08:18:16][version_utils.py:189] Detected vllm version: 0.8.5.post1
[kvcached][ERROR][2025-10-26 08:18:16][patch_base.py:163] Could not import target module vllm.v1.core.kv_cache_coordinator for kv_cache_coordinator: No module named 'vllm.v1.core.kv_cache_coordinator'
[kvcached][WARNING][2025-10-26 08:18:16][patch_base.py:119] Failed to apply kv_cache_coordinator
[kvcached][INFO][2025-10-26 08:18:16][patch_base.py:178] Successfully patched vllm: elastic_block_pool, engine_core, gpu_worker
[kvcached][WARNING][2025-10-26 08:18:16][patch_base.py:180] Failed to patch vllm: gpu_model_runner, kv_cache_coordinator
INFO 10-26 08:18:16 [core.py:58] Initializing a V1 LLM engine (v0.8.5.post1) with config: model='meta-llama/Llama-3.2-1B', speculative_config=None, tokenizer='meta-llama/Llama-3.2-1B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=meta-llama/Llama-3.2-1B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 10-26 08:18:16 [utils.py:2522] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x73b0c9102d20>
INFO 10-26 08:18:17 [parallel_state.py:1004] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 10-26 08:18:17 [cuda.py:221] Using Flash Attention backend on V1 engine.
INFO 10-26 08:18:17 [topk_topp_sampler.py:59] Using FlashInfer for top-p & top-k sampling.
INFO 10-26 08:18:17 [gpu_model_runner.py:1329] Starting to load model meta-llama/Llama-3.2-1B...
INFO 10-26 08:18:17 [weight_utils.py:265] Using model weights format ['*.safetensors']
INFO 10-26 08:18:17 [weight_utils.py:315] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 1.41it/s]
INFO 10-26 08:18:18 [loader.py:458] Loading weights took 0.77 seconds
INFO 10-26 08:18:18 [gpu_model_runner.py:1347] Model loading took 2.3185 GiB and 1.144347 seconds
INFO 10-26 08:18:24 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/a897a73da8/rank_0_0 for vLLM's torch.compile
INFO 10-26 08:18:24 [backends.py:430] Dynamo bytecode transform time: 6.33 s
INFO 10-26 08:18:28 [backends.py:118] Directly load the compiled graph(s) for shape None from the cache, took 3.788 s
INFO 10-26 08:18:29 [monitor.py:33] torch.compile takes 6.33 s in total
INFO 10-26 08:18:30 [kv_cache_utils.py:634] GPU KV cache size: 550,176 tokens
INFO 10-26 08:18:30 [kv_cache_utils.py:637] Maximum concurrency for 131,072 tokens per request: 4.20x
INFO 10-26 08:18:47 [gpu_model_runner.py:1686] Graph capturing finished in 17 secs, took 0.28 GiB
INFO 10-26 08:18:47 [core.py:159] init engine (profile, create kv cache, warmup model) took 29.19 seconds
INFO 10-26 08:18:47 [core_client.py:439] Core engine process 0 ready.
WARNING 10-26 08:18:47 [config.py:1239] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
INFO 10-26 08:18:47 [serving_chat.py:118] Using default chat sampling params from model: {'temperature': 0.6, 'top_p': 0.9}
INFO 10-26 08:18:47 [serving_completion.py:61] Using default completion sampling params from model: {'temperature': 0.6, 'top_p': 0.9}
INFO 10-26 08:18:47 [api_server.py:1090] Starting vLLM API server on http://0.0.0.0:12346
INFO 10-26 08:18:47 [launcher.py:28] Available routes are:
INFO 10-26 08:18:47 [launcher.py:36] Route: /openapi.json, Methods: HEAD, GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /docs, Methods: HEAD, GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /docs/oauth2-redirect, Methods: HEAD, GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /redoc, Methods: HEAD, GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /health, Methods: GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /load, Methods: GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /ping, Methods: POST, GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /tokenize, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /detokenize, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/models, Methods: GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /version, Methods: GET
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/chat/completions, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/completions, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/embeddings, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /pooling, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /score, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/score, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/audio/transcriptions, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /rerank, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v1/rerank, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /v2/rerank, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /invocations, Methods: POST
INFO 10-26 08:18:47 [launcher.py:36] Route: /metrics, Methods: GET
INFO: Started server process [504]
INFO: Waiting for application startup.
INFO: Application startup complete.
Oh, I think our version detection does not cover 0.8.5.post1, but that can be quickly fixed. Will update shortly.
Thank you @inforly for the issue! This should be fixed by #194. We'd appreciate it if you could give it another try. The fix is currently on the main branch, so you may need to install kvcached from source (https://github.com/ovg-project/kvcached?tab=readme-ov-file#install-from-source).
Also, as Jiarong pointed out, you don't have to set a value for gpu-memory-utilization; kvcached should be able to detect available memory and coordinate KV cache sizes dynamically.
Thanks!
@ivanium @jiarong0907 thank you very much for the quick fix! I tried running the following script without using --gpu-memory-utilization 0.4, but it still failed with a CUDA out-of-memory error.
Script:
apt-get update && apt-get upgrade -y && apt-get install -y git
git clone https://github.com/ovg-project/kvcached.git
cd kvcached
pip install -e . --no-build-isolation --no-cache-dir
python tools/dev_copy_pth.py
export ENABLE_KVCACHED=true
export KVCACHED_IPC_NAME=VLLM
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
nohup vllm serve Qwen/Qwen3-0.6B --port 8887 --host 0.0.0.0 --disable-log-requests --no-enable-prefix-caching > 1.log 2>&1 &
vllm serve Qwen/Qwen3-0.6B --port 8888 --host 0.0.0.0 --disable-log-requests --no-enable-prefix-caching
Error:
INFO 10-27 05:36:31 [kv_cache_utils.py:634] GPU KV cache size: 591,440 tokens
INFO 10-27 05:36:31 [kv_cache_utils.py:637] Maximum concurrency for 40,960 tokens per request: 14.44x
ERROR 10-27 05:36:31 [core.py:396] EngineCore failed to start.
ERROR 10-27 05:36:31 [core.py:396] Traceback (most recent call last):
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
ERROR 10-27 05:36:31 [core.py:396] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 10-27 05:36:31 [core.py:396] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 329, in __init__
ERROR 10-27 05:36:31 [core.py:396] super().__init__(vllm_config, executor_class, log_stats,
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 71, in __init__
ERROR 10-27 05:36:31 [core.py:396] self._initialize_kv_caches(vllm_config)
Process EngineCore_0:
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 156, in _initialize_kv_caches
ERROR 10-27 05:36:31 [core.py:396] self.model_executor.initialize_from_config(kv_cache_configs)
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 63, in initialize_from_config
ERROR 10-27 05:36:31 [core.py:396] self.collective_rpc("initialize_from_config",
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 10-27 05:36:31 [core.py:396] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 10-27 05:36:31 [core.py:396] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2456, in run_method
ERROR 10-27 05:36:31 [core.py:396] return func(*args, **kwargs)
ERROR 10-27 05:36:31 [core.py:396] ^^^^^^^^^^^^^^^^^^^^^
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 599, in initialize_from_config
ERROR 10-27 05:36:31 [core.py:396] self.worker.initialize_from_config(kv_cache_config) # type: ignore
ERROR 10-27 05:36:31 [core.py:396] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 226, in initialize_from_config
ERROR 10-27 05:36:31 [core.py:396] self.model_runner.initialize_kv_cache(kv_cache_config)
ERROR 10-27 05:36:31 [core.py:396] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1722, in initialize_kv_cache
ERROR 10-27 05:36:31 [core.py:396] kv_caches[layer_name] = torch.zeros(kv_cache_shape,
ERROR 10-27 05:36:31 [core.py:396] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-27 05:36:31 [core.py:396] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.26 GiB. GPU 0 has a total capacity of 79.18 GiB of which 334.38 MiB is free. Process 23224 has 49.13 GiB memory in use. Process 23208 has 29.72 GiB memory in use. Of the allocated memory 48.53 GiB is allocated by PyTorch, and 102.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Traceback (most recent call last):
File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 400, in run_engine_core
raise e
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 387, in run_engine_core
engine_core = EngineCoreProc(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 329, in __init__
super().__init__(vllm_config, executor_class, log_stats,
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 71, in __init__
self._initialize_kv_caches(vllm_config)
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 156, in _initialize_kv_caches
self.model_executor.initialize_from_config(kv_cache_configs)
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/abstract.py", line 63, in initialize_from_config
self.collective_rpc("initialize_from_config",
File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2456, in run_method
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 599, in initialize_from_config
self.worker.initialize_from_config(kv_cache_config) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 226, in initialize_from_config
self.model_runner.initialize_kv_cache(kv_cache_config)
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1722, in initialize_kv_cache
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.26 GiB. GPU 0 has a total capacity of 79.18 GiB of which 334.38 MiB is free. Process 23224 has 49.13 GiB memory in use. Process 23208 has 29.72 GiB memory in use. Of the allocated memory 48.53 GiB is allocated by PyTorch, and 102.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]:[W1027 05:36:32.751792988 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Traceback (most recent call last):
File "/usr/local/bin/vllm", line 10, in <module>
sys.exit(main())
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/main.py", line 53, in main
args.dispatch_function(args)
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/serve.py", line 27, in cmd
uvloop.run(run_server(args))
File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
return __asyncio.run(
^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
return await main
^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 1078, in run_server
async with build_async_engine_client(args) as engine_client:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
return await anext(self.gen)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 146, in build_async_engine_client
async with build_async_engine_client_from_engine_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
return await anext(self.gen)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 178, in build_async_engine_client_from_engine_args
async_llm = AsyncLLM.from_vllm_config(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 150, in from_vllm_config
return cls(
^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/async_llm.py", line 118, in __init__
self.engine_core = core_client_class(
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 642, in __init__
super().__init__(
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 398, in __init__
self._wait_for_engine_startup()
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core_client.py", line 430, in _wait_for_engine_startup
raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above.
Hi @inforly, can you also set export KVCACHED_AUTOPATCH=1? I tested it on my end, and it works with the latest fix.
The example code might be a bit outdated. I just updated it in PR https://github.com/ovg-project/kvcached/pull/196.
Thanks, @jiarong0907! It started successfully this time. I tried with export KVCACHED_AUTOPATCH=1, but there seems to be no performance gain. Could you please take a look?
Script:
apt-get update && apt-get install -y git
git clone https://github.com/ovg-project/kvcached.git
cd kvcached
pip install -e . --no-build-isolation --no-cache-dir
python3 tools/dev_copy_pth.py
export ENABLE_KVCACHED=true
export KVCACHED_AUTOPATCH=1
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
nohup vllm serve Qwen/Qwen3-0.6B --port 8887 --host 0.0.0.0 --disable-log-requests --no-enable-prefix-caching > 1.log 2>&1 &
vllm serve Qwen/Qwen3-0.6B --port 8888 --host 0.0.0.0 --disable-log-requests --no-enable-prefix-caching
Send requests to the 2 instances simultaneously:
vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 10 --num-prompts 1000 --port 8887
vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 10 --num-prompts 1000 --port 8888
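To actually overlap the two runs, the first benchmark is backgrounded, roughly like this (a small shell sketch; the log file name is illustrative):
# run both benchmarks concurrently and wait for both to finish
vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 10 --num-prompts 1000 --port 8887 > bench_8887.log 2>&1 &
vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 10 --num-prompts 1000 --port 8888
wait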
Results without kvcached:
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 103.61
Total input tokens: 1024000
Total generated tokens: 118882
Request throughput (req/s): 9.65
Output token throughput (tok/s): 1147.35
Total Token throughput (tok/s): 11030.14
---------------Time to First Token----------------
Mean TTFT (ms): 38.62
Median TTFT (ms): 37.00
P99 TTFT (ms): 63.26
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 12.09
Median TPOT (ms): 12.19
P99 TPOT (ms): 15.45
---------------Inter-token Latency----------------
Mean ITL (ms): 12.10
Median ITL (ms): 10.69
P99 ITL (ms): 28.67
==================================================
Results with kvcached:
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 104.52
Total input tokens: 1024000
Total generated tokens: 119010
Request throughput (req/s): 9.57
Output token throughput (tok/s): 1138.59
Total Token throughput (tok/s): 10935.39
---------------Time to First Token----------------
Mean TTFT (ms): 39.04
Median TTFT (ms): 37.50
P99 TTFT (ms): 67.98
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 12.54
Median TPOT (ms): 12.49
P99 TPOT (ms): 17.57
---------------Inter-token Latency----------------
Mean ITL (ms): 12.53
Median ITL (ms): 10.83
P99 ITL (ms): 31.69
==================================================
Great to see it runs! I assume you set a static gpu_mem_utilization (e.g., 0.4 for each model) when disabling kvcached. If so, similar performance is expected: the main goal of kvcached is to let multiple models run together and share GPU memory for their KV caches, rather than relying on static reservation. Since the load here is low and does not saturate GPU memory, kvcached's dynamic coordination does not make a significant difference.
The performance difference will be more significant when GPU memory becomes the bottleneck. For example, I would suggest serving longer requests at higher request rates to saturate GPU memory. You can also apply imbalanced loads to the two models, say one with 3x the request rate of the other; in that case a static gpu_mem_utilization split is suboptimal, but kvcached can adaptively adjust the KV cache sizes to reflect the 3:1 ratio.
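For instance, an imbalanced-load test could look roughly like this (a sketch reusing the bench command from above; the rates and prompt counts are just one example of a 3:1 ratio):
# instance 1 at a low rate, instance 2 at 3x the rate, run concurrently
vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 10 --num-prompts 1000 --port 8887 &
vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 30 --num-prompts 3000 --port 8888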
@ivanium thanks a lot for the detailed explanation! Actually, without kvcached, we couldn't even start the 2 instances successfully every time! I followed your suggestion and ran the test; here are the results:
instance 1: vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 10 --num-prompts 1000 --port 8887
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 104.08
Total input tokens: 1024000
Total generated tokens: 119007
Request throughput (req/s): 9.61
Output token throughput (tok/s): 1143.44
Total Token throughput (tok/s): 10982.20
---------------Time to First Token----------------
Mean TTFT (ms): 92.67
Median TTFT (ms): 47.83
P99 TTFT (ms): 1220.23
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 15.51
Median TPOT (ms): 15.19
P99 TPOT (ms): 24.52
---------------Inter-token Latency----------------
Mean ITL (ms): 15.98
Median ITL (ms): 14.86
P99 ITL (ms): 47.01
==================================================
instance 2: vllm bench serve --model Qwen/Qwen3-0.6B --request-rate 30 --num-prompts 3000 --port 8888
============ Serving Benchmark Result ============
Successful requests: 3000
Benchmark duration (s): 137.72
Total input tokens: 3072000
Total generated tokens: 356802
Request throughput (req/s): 21.78
Output token throughput (tok/s): 2590.84
Total Token throughput (tok/s): 24897.54
---------------Time to First Token----------------
Mean TTFT (ms): 17214.56
Median TTFT (ms): 21294.43
P99 TTFT (ms): 32700.50
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 176.09
Median TPOT (ms): 197.33
P99 TPOT (ms): 239.17
---------------Inter-token Latency----------------
Mean ITL (ms): 175.83
Median ITL (ms): 171.37
P99 ITL (ms): 297.34
==================================================
We can even run 2 Qwen3-8B instances:
instance 1: vllm bench serve --model Qwen/Qwen3-8B --request-rate 10 --num-prompts 1000 --port 8887
============ Serving Benchmark Result ============
Successful requests: 1000
Benchmark duration (s): 288.88
Total input tokens: 1024000
Total generated tokens: 123136
Request throughput (req/s): 3.46
Output token throughput (tok/s): 426.25
Total Token throughput (tok/s): 3970.97
---------------Time to First Token----------------
Mean TTFT (ms): 89074.74
Median TTFT (ms): 91861.29
P99 TTFT (ms): 181301.23
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 133.44
Median TPOT (ms): 134.52
P99 TPOT (ms): 200.68
---------------Inter-token Latency----------------
Mean ITL (ms): 127.67
Median ITL (ms): 51.46
P99 ITL (ms): 1317.77
==================================================
instance 2: vllm bench serve --model Qwen/Qwen3-8B --request-rate 30 --num-prompts 3000 --port 8888
============ Serving Benchmark Result ============
Successful requests: 3000
Benchmark duration (s): 524.33
Total input tokens: 3072000
Total generated tokens: 369034
Request throughput (req/s): 5.72
Output token throughput (tok/s): 703.82
Total Token throughput (tok/s): 6562.69
---------------Time to First Token----------------
Mean TTFT (ms): 245769.67
Median TTFT (ms): 278905.58
P99 TTFT (ms): 413392.21
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 262.62
Median TPOT (ms): 214.20
P99 TPOT (ms): 438.25
---------------Inter-token Latency----------------
Mean ITL (ms): 259.95
Median ITL (ms): 93.01
P99 ITL (ms): 1425.79
==================================================
At lower request rates (--request-rate 10), GPU utilization can be greatly reduced with kvcached.
Thanks again for the great work!
That's a great trial and great results with kvcached! Thanks for sharing them.