text-generation-inference
Odd CUDA OOM
System Info
Using the latest version of TGI on EKS with a single A10G GPU, loading the flan-t5-xl model without quantization.
With the following token limits, the model loads, and after warmup it uses about 9.5 GB of VRAM on the GPU. Token limits:
- name: MAX_INPUT_LENGTH
value: "4000"
- name: MAX_BATCH_PREFILL_TOKENS
value: "4100"
- name: MAX_BATCH_TOTAL_TOKENS
value: "4400"
- name: MAX_TOTAL_TOKENS
value: "4400"
GPU RAM usage after warmup:
$ nvidia-smi
Fri Jul 7 17:58:29 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03 Driver Version: 470.182.03 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A10G On | 00000000:00:1E.0 Off | 0 |
| 0% 36C P0 64W / 300W | 9542MiB / 22731MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
+-----------------------------------------------------------------------------+
If I increase the token limits by just ~10%, I suddenly run out of GPU memory. Slightly increased token limits:
- name: MAX_INPUT_LENGTH
value: "4500"
- name: MAX_BATCH_PREFILL_TOKENS
value: "4600"
- name: MAX_BATCH_TOTAL_TOKENS
value: "4900"
- name: MAX_TOTAL_TOKENS
value: "4900"
Error thrown:
2023-07-07T17:53:15.222817Z INFO text_generation_launcher: Args { model_id: "google/flan-t5-xl", revision: None, sharded: Some(false), num_shard: None, quantize: None, dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_input_length: 4500, max_total_tokens: 4900, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4600, max_batch_total_tokens: 4900, max_waiting_tokens: 20, hostname: "flan-t5-xl-65b5947f77-d79l8", port: 80, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_domain: None, ngrok_username: None, ngrok_password: None, env: false }
2023-07-07T17:53:15.222927Z INFO text_generation_launcher: Starting download process.
2023-07-07T17:53:17.000298Z WARN download: text_generation_launcher: No safetensors weights found for model google/flan-t5-xl at revision None. Downloading PyTorch weights.
2023-07-07T17:53:17.059103Z INFO download: text_generation_launcher: Download file: pytorch_model-00001-of-00002.bin
2023-07-07T17:53:35.885816Z INFO download: text_generation_launcher: Downloaded /data/models--google--flan-t5-xl/snapshots/53fd1e22aa944eee1fd336f9aee8a437e01676ce/pytorch_model-00001-of-00002.bin in 0:00:18.
2023-07-07T17:53:35.885872Z INFO download: text_generation_launcher: Download: [1/2] -- ETA: 0:00:18
2023-07-07T17:53:35.886112Z INFO download: text_generation_launcher: Download file: pytorch_model-00002-of-00002.bin
2023-07-07T17:53:42.973060Z INFO download: text_generation_launcher: Downloaded /data/models--google--flan-t5-xl/snapshots/53fd1e22aa944eee1fd336f9aee8a437e01676ce/pytorch_model-00002-of-00002.bin in 0:00:07.
2023-07-07T17:53:42.973144Z INFO download: text_generation_launcher: Download: [2/2] -- ETA: 0
2023-07-07T17:53:42.973236Z WARN download: text_generation_launcher: No safetensors weights found for model google/flan-t5-xl at revision None. Converting PyTorch weights to safetensors.
2023-07-07T17:55:10.575458Z INFO download: text_generation_launcher: Convert: [1/2] -- Took: 0:01:27.308986
2023-07-07T17:55:21.870011Z INFO download: text_generation_launcher: Convert: [2/2] -- Took: 0:00:11.302132
2023-07-07T17:55:23.385415Z INFO text_generation_launcher: Successfully downloaded weights.
2023-07-07T17:55:23.386189Z INFO text_generation_launcher: Starting shard 0
2023-07-07T17:55:28.399204Z WARN shard-manager: text_generation_launcher: We're not using custom kernels.
rank=0
2023-07-07T17:55:33.396703Z INFO text_generation_launcher: Waiting for shard 0 to be ready...
2023-07-07T17:55:40.689392Z INFO shard-manager: text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
rank=0
2023-07-07T17:55:40.703506Z INFO text_generation_launcher: Shard 0 ready in 17.315447111s
2023-07-07T17:55:40.801254Z INFO text_generation_launcher: Starting Webserver
2023-07-07T17:55:41.130302Z INFO text_generation_router: router/src/main.rs:208: Warming up model
2023-07-07T17:55:43.686576Z ERROR shard-manager: text_generation_launcher: Method Warmup encountered an error.
Traceback (most recent call last):
File "/opt/conda/bin/text-generation-server", line 8, in <module>
sys.exit(app())
File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 311, in __call__
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1657, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 760, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 78, in serve
server.serve(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 166, in serve
asyncio.run(
File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
self._run_once()
File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
handle._run()
File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/opt/conda/lib/python3.9/site-packages/grpc_interceptor/server.py", line 159, in invoke_intercept_method
return await self.intercept(
> File "/opt/conda/lib/python3.9/site-packages/text_generation_server/interceptor.py", line 20, in intercept
return await response
File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
raise error
File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
return await behavior(request_or_iterator, context)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 60, in Warmup
self.model.warmup(batch, request.max_total_tokens)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/model.py", line 62, in warmup
self.generate_token(batch)
File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/seq2seq_lm.py", line 607, in generate_token
logits, encoder_last_hidden_state, past = self.forward(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/t5.py", line 88, in forward
outputs = self.model.forward(
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 1070, in forward
encoder_outputs = self.encoder(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 931, in forward
layer_outputs = layer_module(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 639, in forward
self_attention_outputs = self.layer[0](
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 515, in forward
attention_output = self.SelfAttention(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 462, in forward
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 1843, in softmax
ret = input.softmax(dim)
torch.cuda.OutOfMemoryError: Allocation on device 0 would exceed allowed memory. (out of memory)
Currently allocated : 16.08 GiB
Requested : 4.83 GiB
Device limit : 22.20 GiB
Free (according to CUDA): 11.12 MiB
PyTorch limit (set by user-supplied memory fraction)
: 22.20 GiB
rank=0
2023-07-07T17:55:43.686922Z ERROR warmup{max_input_length=4500 max_prefill_tokens=4600 max_total_tokens=4900}:warmup{max_input_length=4500 max_prefill_tokens=4600 max_total_tokens=4900}: text_generation_client: router/client/src/lib.rs:33: Server error: Allocation on device 0 would exceed allowed memory. (out of memory)
Currently allocated : 16.08 GiB
Requested : 4.83 GiB
Device limit : 22.20 GiB
Free (according to CUDA): 11.12 MiB
PyTorch limit (set by user-supplied memory fraction)
: 22.20 GiB
thread 'main' panicked at 'Unable to warmup model: Generation("Allocation on device 0 would exceed allowed memory. (out of memory)\nCurrently allocated : 16.08 GiB\nRequested : 4.83 GiB\nDevice limit : 22.20 GiB\nFree (according to CUDA): 11.12 MiB\nPyTorch limit (set by user-supplied memory fraction)\n : 22.20 GiB")', router/src/main.rs:216:18
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
2023-07-07T17:55:43.805488Z ERROR text_generation_launcher: Webserver Crashed
2023-07-07T17:55:43.805514Z INFO text_generation_launcher: Shutting down shards
2023-07-07T17:55:44.134207Z INFO text_generation_launcher: Shard 0 terminated
Error: WebserverFailed
This is very odd.
I am sure memory usage is not linear, but this still doesn't seem right. Is it possible that there is an initial spike in GPU memory utilization during warmup, after which it settles at a lower number, and that is why nvidia-smi reports such a low figure once warmup has completed?
P.S.: Raising the token limit from ~3000 to ~4000 only increased the reported GPU memory utilization from about 9.1 GB to 9.5 GB.
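As a rough, hedged back-of-envelope (not TGI's actual memory accounting): the traceback fails in the softmax over the encoder self-attention scores, a tensor of shape (batch, n_heads, seq_len, seq_len) that grows quadratically with the input length. Assuming flan-t5-xl's 32 attention heads and a single warmup sequence, the sizes line up with the "Requested: 4.83 GiB" in the error:

```python
# Rough back-of-envelope only, not TGI's memory accounting. Assumptions:
# flan-t5-xl uses 32 attention heads, the warmup prefill is a single
# sequence, and the softmax in the traceback runs on float32 scores.

def score_tensor_gib(seq_len: int, n_heads: int = 32, batch: int = 1,
                     bytes_per_elem: int = 4) -> float:
    """Size in GiB of one (batch, n_heads, seq_len, seq_len) float32 tensor."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem / 1024**3

for seq_len in (3000, 4000, 4500):
    print(f"{seq_len:>4} tokens -> {score_tensor_gib(seq_len):.2f} GiB per score tensor")

# ~1.07 GiB at 3000 tokens, ~1.91 GiB at 4000, ~2.41 GiB at 4500.
# The float() + softmax materializes roughly two tensors of that size,
# which matches "Requested: 4.83 GiB": a 12.5% longer input asks for
# ~27% more memory in this single allocation, on top of everything
# already resident.
```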
Information
- [ ] Docker
- [ ] The CLI directly
Tasks
- [ ] An officially supported command
- [ ] My own modifications
Reproduction
- Load the flan-t5-xl model without quantization on an A10G GPU with the following token limits:
- name: MAX_INPUT_LENGTH
value: "4000"
- name: MAX_BATCH_PREFILL_TOKENS
value: "4100"
- name: MAX_BATCH_TOTAL_TOKENS
value: "4400"
- name: MAX_TOTAL_TOKENS
value: "4400"
Expected behavior
GPU memory utilization should scale more linearly with the token limits, or an explanation should be provided for the initial spike in GPU memory usage during warmup.
The warmup phase (the one crashing) tries to allocate the MAXIMUM possible request, mimicking your server under load.
text_generation_launcher: Method Warmup encountered an error.
We try to provide a better error, but I'm guessing the system is failing here. @OlivierDehaene, shouldn't the warmup errors be caught and re-raised as nicer-to-read errors? (I think the mechanism already exists, but I don't get why it's not working here.)
Thanks @Narsil, so as I suspected, the warmup phase generates a memory spike.
But is this only a temporary spike? Meaning, is the GPU memory freed up again after warmup completes?
That is what I seem to be seeing: the memory usage reported by nvidia-smi is well below the GPU's maximum after a successful warmup with 4000 tokens (while the warmup with 4500 tokens cannot complete because it runs out of memory).
Yes. In general, though, PyTorch allocates memory however it likes, so the numbers reported by nvidia-smi might not reflect what is actually necessary.
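One way to see the difference (a minimal sketch using standard torch.cuda APIs, run from inside the server process or a Python shell on the same GPU; TGI does not expose this directly) is to compare live, reserved, and peak memory as tracked by PyTorch's caching allocator:

```python
import torch

# PyTorch's caching allocator keeps freed blocks around for reuse, so
# nvidia-smi reports the whole reserved pool plus CUDA context overhead,
# not just what live tensors currently occupy.
gib = 1024 ** 3
allocated = torch.cuda.memory_allocated(0)   # bytes held by live tensors
reserved = torch.cuda.memory_reserved(0)     # bytes held by the caching allocator
peak = torch.cuda.max_memory_allocated(0)    # high-water mark, e.g. the warmup spike

print(f"allocated: {allocated / gib:.2f} GiB")
print(f"reserved : {reserved / gib:.2f} GiB")
print(f"peak     : {peak / gib:.2f} GiB")
# torch.cuda.memory_summary(0) gives a more detailed breakdown.
```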
Got the same issue with the official Docker image v0.9.3. I'm using an AWS EC2 g5 instance with A10G GPUs (24 GB of GPU memory each). The model I'm using is MPT-7B-8k. Here are the relevant configs:
--max-total-tokens 8192
--max-input-length 8064 # 8192 - 128
--max-batch-total-tokens 8192
--max-batch-prefill-tokens 8064
--num-shard 4
This results in OOM during warmup. Meanwhile, I also tried a shorter sequence length of 6144 and it worked. After the server is ready, nvidia-smi shows the memory consumption is just ~9 GB per GPU. Given that MPT-7B has 32 layers and a hidden size of 4096, sharding it across 4 GPUs with 24 GB each should be sufficient for a sequence length of 8192. Did I miss anything?
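For what it's worth, a rough sanity check of that arithmetic (using only the figures quoted above, with fp16 assumed; it deliberately ignores vLLM's pre-allocated block pool and the warmup prefill activations, which is where the OOM actually fires):

```python
# Back-of-envelope only, not how TGI/vLLM budgets memory. Assumes fp16
# (2 bytes/element) and the figures quoted above: 32 layers, hidden size
# 4096, sequence length 8192, ~7B parameters, sharded over 4 GPUs.
n_layers, d_model, bytes_fp16 = 32, 4096, 2
seq_len, n_params, n_shards = 8192, 7e9, 4
gib = 1024 ** 3

kv_cache = 2 * n_layers * d_model * seq_len * bytes_fp16  # K and V for one sequence
weights = n_params * bytes_fp16

print(f"KV cache, one 8192-token sequence: {kv_cache / gib:.2f} GiB")
print(f"weights (fp16):                    {weights / gib:.2f} GiB")
print(f"per GPU with --num-shard 4:        ~{(kv_cache + weights) / n_shards / gib:.2f} GiB")
# ~4 GiB of KV cache plus ~13 GiB of weights, i.e. roughly 4.3 GiB per
# shard, so the static footprint alone should indeed fit on a 24 GB card.
```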
@comaniac, I am still using the 4k version with a single A10G and that works, but I believe there was a bugfix related to this which was merged after v0.9.3, so try v0.9.4 (which I use with the 4k model).
I've actually tried v1.0.0, but no luck. My guess is that the vLLM cache calculation doesn't really take the sequence length into account and may allocate too many memory blocks to the KV cache, which results in OOM when attempting to allocate memory for activations during warmup.
0.9.3 had issues because we were using AsyncMalloc, and it seems PyTorch doesn't do a great job of tracking those allocations, leading to all sorts of issues everywhere. We rolled that back for 0.9.4 (I'm not sure about the exact versions, but it sounds right).
@comaniac vLLM uses blocks, so at peak it can definitely use more memory than the theoretical maximum (a single token that creates a new block will reserve 128 slots, and this happens for every request).
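A tiny illustration of that rounding effect (assuming a block size of 128 slots, per the comment above; the real allocator lives in TGI/vLLM, this is just the arithmetic):

```python
import math

# Block-based KV allocation rounds every request up to whole blocks, so
# worst-case usage exceeds a naive "tokens x per-token cost" estimate.
BLOCK_SIZE = 128  # slots per block, as mentioned above (assumed)

def reserved_slots(tokens: int) -> int:
    """Slots actually reserved for a request holding `tokens` tokens."""
    return math.ceil(tokens / BLOCK_SIZE) * BLOCK_SIZE

for tokens in (128, 129, 8191, 8192, 8193):
    print(f"{tokens:>5} tokens -> {reserved_slots(tokens)} slots reserved")
# 129 tokens already reserves 256 slots: one token past a block boundary
# costs a whole extra block, for every request in the batch.
```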
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.