[Bug]: Requests to `trtllm-serve serve` are executed sequentially if `max_tokens` is not provided
System Info
- GPU: NVIDIA A100 80GB PCIe
- Driver Version: 570.172.08
- CUDA Version: 12.8
- OS: Ubuntu 22.04 (jammy)
Container launched with the following docker compose file:
services:
  tensorrt-llm:
    image: nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc2
    container_name: tensorrt-llm-container
    ports:
      - "28001:8001"
    volumes:
      - ./models:/app/tensorrt_llm/models
      - ./configs:/app/tensorrt_llm/configs
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ["0", "1"]
              capabilities: [gpu]
    command: >
      trtllm-serve serve ./models/gemma3-27b
      --host 0.0.0.0
      --port 8001
      --max_batch_size 32
      --tp_size 2
      --log_level debug
      --extra_llm_api_options ./configs/conf_fast.yaml
and started with the command:
docker compose -f container_fast.yaml up -d
./configs/conf_fast.yaml:
attn_backend: "FLASHINFER"
Who can help?
No response
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [x] My own task or dataset (give details below)
Reproduction
Reproduce with:
import asyncio
import time

from openai import AsyncOpenAI

from story.models.chat import Chat
from story.models.openai import Message

url = "https://api.example.com/v1"

chat = Chat(
    messages=[
        Message(role="system", content="You are a helpful assistant.\n"),
        Message(role="user", content="Provide a list of the first 10 prime numbers."),
    ]
)

model_name = "google/gemma-3-27b-it"
client = AsyncOpenAI(base_url=url, api_key="your-api-key")


async def send_request(max_tokens: int | None) -> float:
    args = {
        "model": model_name,
        "messages": chat.messages,
        "stop": ["<end_of_turn>"],
    }
    if max_tokens:
        args["max_tokens"] = max_tokens
    start_time = time.time()
    _ = await client.chat.completions.create(**args)
    elapsed_time = time.time() - start_time
    return elapsed_time


async def main():
    for n_tokens in [None, 512]:
        tasks = [send_request(max_tokens=n_tokens) for _ in range(3)]
        results = await asyncio.gather(*tasks)
        for i, elapsed in enumerate(results, 1):
            print(f"Request {i} with max_tokens={n_tokens} completed in {elapsed:.2f} seconds")
        print("------")


asyncio.run(main())
This produces the following output:
Request 1 with max_tokens=None completed in 4.85 seconds
Request 2 with max_tokens=None completed in 7.13 seconds
Request 3 with max_tokens=None completed in 2.37 seconds
------
Request 1 with max_tokens=512 completed in 2.54 seconds
Request 2 with max_tokens=512 completed in 2.59 seconds
Request 3 with max_tokens=512 completed in 2.52 seconds
The following lines can be found in the container logs:
With max_tokens=None:
[TRT-LLM] [RANK 0] [V] has 3 active_request, scheduled 0 context requests and 1 generation requests
With max_tokens=512:
[TRT-LLM] [RANK 0] [V] has 3 active_request, scheduled 0 context requests and 3 generation requests
Expected behavior
Requests are also processed concurrently when max_tokens is not provided.
Alternatively, a warning in the container logs would be nice.
actual behavior
Requests are executed in sequence when max_tokens is not provided.
additional notes
Please let me know if any additional information is needed.
Before submitting a new issue...
- [x] Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.
A random guess of mine is that TRT-LLM will assume a larger max sequence length for the input sequence, which causes only 1 request to be scheduled instead of the 3 requests you expected.
In our application, we found that by default TRT-LLM uses the GUARANTEED_NO_EVICT scheduler policy, which allocates KV cache for the full request length (up to max_tokens) for each request and blocks a request from starting if that allocation would exceed the available memory. If max_tokens is not specified, the default can be large.
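As a client-side stop-gap (separate from the server-side change below), always passing an explicit, modest max_tokens should keep the per-request KV cache reservation small. This is just a sketch, reusing the client, model_name and chat objects from your reproduction script, with 512 as an illustrative cap:

# Sketch of a client-side workaround: cap the output length explicitly so the
# scheduler only needs to reserve KV cache for 512 output tokens per request.
# The value 512 is illustrative; choose whatever fits your use case.
response = await client.chat.completions.create(
    model=model_name,
    messages=chat.messages,
    stop=["<end_of_turn>"],
    max_tokens=512,
)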
@Jensonah1 Can you try adding this to your extra-llm-options.yaml to enable the "max utilization" mode so less kv cache memory is needed to start responding to each request?
scheduler_config:
  capacity_scheduler_policy: MAX_UTILIZATION
See some of the documentation here: https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/legacy/performance/performance-tuning-guide/useful-runtime-flags.md#capacity-scheduler-policy I'm not sure why this ended up in the legacy docs, but from what we can tell, the new TensorRT-LLM engine still uses the same capacity scheduler policy.
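For completeness, the resulting ./configs/conf_fast.yaml from your report would then look something like this (just your existing attn_backend line plus the suggested scheduler block; adjust as needed):

attn_backend: "FLASHINFER"
scheduler_config:
  capacity_scheduler_policy: MAX_UTILIZATION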
Thank you @pathorn! I can confirm that this works. I do remember having come across this before and trying it out, but at the time it didn't seem to work. My bad, I guess. Could I ask you to make sure this solution gets a (more prominent) place in the documentation? That would save other people like me from losing time :)
@Jensonah1, thank you for raising the issue here, and @pathorn, thank you for sharing your insights! Closing this issue based on the discussion above.