
Odd CUDA OOM

zoltan-fedor opened this issue on Jul 7, 2023 · 7 comments

System Info

Using the latest version of TGI on EKS with a single A10G GPU, loading the flan-t5-xl model without quantization.

With the following token limits, the model loads and after warmup uses about 9.5 GB of VRAM on the GPU:

      - name: MAX_INPUT_LENGTH
        value: "4000" 
      - name: MAX_BATCH_PREFILL_TOKENS
        value: "4100"
      - name: MAX_BATCH_TOTAL_TOKENS
        value: "4400"
      - name: MAX_TOTAL_TOKENS
        value: "4400"

GPU RAM usage after warmup:

$ nvidia-smi
Fri Jul  7 17:58:29 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   36C    P0    64W / 300W |   9542MiB / 22731MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

However, if I increase the token limits by just ~10%, I suddenly run out of GPU memory. Slightly increased token limits:

      - name: MAX_INPUT_LENGTH
        value: "4500"
      - name: MAX_BATCH_PREFILL_TOKENS
        value: "4600"
      - name: MAX_BATCH_TOTAL_TOKENS
        value: "4900"
      - name: MAX_TOTAL_TOKENS
        value: "4900"

Error thrown:

2023-07-07T17:53:15.222817Z  INFO text_generation_launcher: Args { model_id: "google/flan-t5-xl", revision: None, sharded: Some(false), num_shard: None, quantize: None, dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_input_length: 4500, max_total_tokens: 4900, waiting_served_ratio: 1.2, max_batch_prefill_tokens: 4600, max_batch_total_tokens: 4900, max_waiting_tokens: 20, hostname: "flan-t5-xl-65b5947f77-d79l8", port: 80, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_domain: None, ngrok_username: None, ngrok_password: None, env: false }
2023-07-07T17:53:15.222927Z  INFO text_generation_launcher: Starting download process.
2023-07-07T17:53:17.000298Z  WARN download: text_generation_launcher: No safetensors weights found for model google/flan-t5-xl at revision None. Downloading PyTorch weights.

2023-07-07T17:53:17.059103Z  INFO download: text_generation_launcher: Download file: pytorch_model-00001-of-00002.bin

2023-07-07T17:53:35.885816Z  INFO download: text_generation_launcher: Downloaded /data/models--google--flan-t5-xl/snapshots/53fd1e22aa944eee1fd336f9aee8a437e01676ce/pytorch_model-00001-of-00002.bin in 0:00:18.

2023-07-07T17:53:35.885872Z  INFO download: text_generation_launcher: Download: [1/2] -- ETA: 0:00:18

2023-07-07T17:53:35.886112Z  INFO download: text_generation_launcher: Download file: pytorch_model-00002-of-00002.bin

2023-07-07T17:53:42.973060Z  INFO download: text_generation_launcher: Downloaded /data/models--google--flan-t5-xl/snapshots/53fd1e22aa944eee1fd336f9aee8a437e01676ce/pytorch_model-00002-of-00002.bin in 0:00:07.

2023-07-07T17:53:42.973144Z  INFO download: text_generation_launcher: Download: [2/2] -- ETA: 0

2023-07-07T17:53:42.973236Z  WARN download: text_generation_launcher: No safetensors weights found for model google/flan-t5-xl at revision None. Converting PyTorch weights to safetensors.

2023-07-07T17:55:10.575458Z  INFO download: text_generation_launcher: Convert: [1/2] -- Took: 0:01:27.308986

2023-07-07T17:55:21.870011Z  INFO download: text_generation_launcher: Convert: [2/2] -- Took: 0:00:11.302132

2023-07-07T17:55:23.385415Z  INFO text_generation_launcher: Successfully downloaded weights.
2023-07-07T17:55:23.386189Z  INFO text_generation_launcher: Starting shard 0
2023-07-07T17:55:28.399204Z  WARN shard-manager: text_generation_launcher: We're not using custom kernels.
 rank=0
2023-07-07T17:55:33.396703Z  INFO text_generation_launcher: Waiting for shard 0 to be ready...
2023-07-07T17:55:40.689392Z  INFO shard-manager: text_generation_launcher: Server started at unix:///tmp/text-generation-server-0
 rank=0
2023-07-07T17:55:40.703506Z  INFO text_generation_launcher: Shard 0 ready in 17.315447111s
2023-07-07T17:55:40.801254Z  INFO text_generation_launcher: Starting Webserver
2023-07-07T17:55:41.130302Z  INFO text_generation_router: router/src/main.rs:208: Warming up model
2023-07-07T17:55:43.686576Z ERROR shard-manager: text_generation_launcher: Method Warmup encountered an error.
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.9/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.9/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/cli.py", line 78, in serve
    server.serve(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 166, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.9/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 634, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.9/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.9/site-packages/grpc_interceptor/server.py", line 159, in invoke_intercept_method
    return await self.intercept(
> File "/opt/conda/lib/python3.9/site-packages/text_generation_server/interceptor.py", line 20, in intercept
    return await response
  File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 82, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.9/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 73, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 60, in Warmup
    self.model.warmup(batch, request.max_total_tokens)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/model.py", line 62, in warmup
    self.generate_token(batch)
  File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/seq2seq_lm.py", line 607, in generate_token
    logits, encoder_last_hidden_state, past = self.forward(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/t5.py", line 88, in forward
    outputs = self.model.forward(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 1070, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 931, in forward
    layer_outputs = layer_module(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 639, in forward
    self_attention_outputs = self.layer[0](
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 515, in forward
    attention_output = self.SelfAttention(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/custom_modeling/t5_modeling.py", line 462, in forward
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 1843, in softmax
    ret = input.softmax(dim)
torch.cuda.OutOfMemoryError: Allocation on device 0 would exceed allowed memory. (out of memory)
Currently allocated     : 16.08 GiB
Requested               : 4.83 GiB
Device limit            : 22.20 GiB
Free (according to CUDA): 11.12 MiB
PyTorch limit (set by user-supplied memory fraction)
                        : 22.20 GiB
 rank=0
2023-07-07T17:55:43.686922Z ERROR warmup{max_input_length=4500 max_prefill_tokens=4600 max_total_tokens=4900}:warmup{max_input_length=4500 max_prefill_tokens=4600 max_total_tokens=4900}: text_generation_client: router/client/src/lib.rs:33: Server error: Allocation on device 0 would exceed allowed memory. (out of memory)
Currently allocated     : 16.08 GiB
Requested               : 4.83 GiB
Device limit            : 22.20 GiB
Free (according to CUDA): 11.12 MiB
PyTorch limit (set by user-supplied memory fraction)
                        : 22.20 GiB
thread 'main' panicked at 'Unable to warmup model: Generation("Allocation on device 0 would exceed allowed memory. (out of memory)\nCurrently allocated     : 16.08 GiB\nRequested               : 4.83 GiB\nDevice limit            : 22.20 GiB\nFree (according to CUDA): 11.12 MiB\nPyTorch limit (set by user-supplied memory fraction)\n                        : 22.20 GiB")', router/src/main.rs:216:18
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
2023-07-07T17:55:43.805488Z ERROR text_generation_launcher: Webserver Crashed
2023-07-07T17:55:43.805514Z  INFO text_generation_launcher: Shutting down shards
2023-07-07T17:55:44.134207Z  INFO text_generation_launcher: Shard 0 terminated
Error: WebserverFailed

This is very odd. I understand that memory usage is not linear, but this still doesn't seem right. Is it possible that there is an initial spike in GPU memory utilization during warmup, after which usage settles at a lower level, and that is why nvidia-smi reports such a low number once warmup has completed?

P.S.: Raising the token limit from ~3000 to ~4000 only increased the reported GPU memory utilization from about 9.1 GB to 9.5 GB.
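
As a rough back-of-envelope sketch of the non-linearity (assuming a single 4,500-token sequence, flan-t5-xl's 32 attention heads, and that both the fp32 copy created by scores.float() and the softmax output are resident at the same time), the attention-score tensors at the line shown in the traceback grow quadratically with the input length:

def scores_gib(seq_len, num_heads=32, bytes_per_element=4):
    # One fp32 [heads, seq, seq] score tensor, counted twice: the copy made by
    # scores.float() plus the tensor softmax allocates for its output.
    return 2 * num_heads * seq_len * seq_len * bytes_per_element / 1024**3

print(f"{scores_gib(4000):.2f} GiB")  # ~3.8 GiB
print(f"{scores_gib(4500):.2f} GiB")  # ~4.8 GiB, close to the 4.83 GiB request in the error

Under these assumptions the warmup peak grows with the square of the input length, so a modest bump in the limits costs disproportionately more at the peak, while the steady-state number nvidia-smi reports afterwards barely moves.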

Information

  • [ ] Docker
  • [ ] The CLI directly

Tasks

  • [ ] An officially supported command
  • [ ] My own modifications

Reproduction

  1. Load the flan-t5-xl model without quantization on an A10G GPU with the following token limits:
      - name: MAX_INPUT_LENGTH
        value: "4000" 
      - name: MAX_BATCH_PREFILL_TOKENS
        value: "4100"
      - name: MAX_BATCH_TOTAL_TOKENS
        value: "4400"
      - name: MAX_TOTAL_TOKENS
        value: "4400"

Expected behavior

GPU memory utilization should scale more linearly with the token limits - or an explanation should be provided for the initial spike in GPU memory usage during warmup.

zoltan-fedor avatar Jul 07 '23 18:07 zoltan-fedor

The warmup phase (the one that is crashing) tries to allocate the MAXIMUM possible request, mimicking your server under load.

text_generation_launcher: Method Warmup encountered an error.

We try to provide a better error, but I'm guessing the system is failing here. @OlivierDehaene, shouldn't the warmup errors be caught and re-raised as nicer-to-read errors? (I think the mechanism already exists, but I don't get why it's not working here.)
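
For readers following along, the warmup step amounts to running one worst-case request up front so that an OOM surfaces at startup rather than under load. A minimal sketch of the idea (not TGI's actual code; the dummy batch construction and the seq2seq call signature are simplified assumptions):

import torch

def warmup(model, max_prefill_tokens, device="cuda"):
    """Run one worst-case prefill so an OOM shows up at startup, not under load."""
    # Dummy prompt of the maximum prefill size; real TGI builds a proper Batch
    # object from tokenized requests instead.
    input_ids = torch.zeros((1, max_prefill_tokens), dtype=torch.long, device=device)
    try:
        with torch.no_grad():
            model(input_ids=input_ids, decoder_input_ids=input_ids[:, :1])
    except torch.cuda.OutOfMemoryError as exc:
        raise RuntimeError(
            f"Not enough GPU memory for max_batch_prefill_tokens={max_prefill_tokens}"
        ) from exc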

Narsil avatar Jul 09 '23 08:07 Narsil

Thanks @Narsil, so as I suspected, the warmup phase generates a memory spike.

But is this only a temporary spike, meaning the GPU memory is freed up after warmup completes? That is what I seem to be seeing - the memory usage reported by nvidia-smi is well below the GPU's maximum after a successful warmup with 4000 tokens (while warmup with 4500 tokens cannot complete because it runs out of memory).

zoltan-fedor avatar Jul 09 '23 13:07 zoltan-fedor

Yes, though in general PyTorch allocates memory however it likes, so what nvidia-smi reports might not really reflect what is actually necessary.
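
One quick way to see this from inside the process is to compare PyTorch's own counters with what nvidia-smi shows; the caching allocator keeps freed blocks reserved, so the nvidia-smi figure sits closer to "reserved plus CUDA context" than to what is live right now. For example:

import torch

print(f"allocated:      {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")      # live tensors
print(f"reserved:       {torch.cuda.memory_reserved() / 1024**3:.2f} GiB")       # cached by the allocator
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")  # high-water mark, e.g. during warmup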

Narsil avatar Jul 09 '23 14:07 Narsil

Got the same issue with the official Docker image v0.9.3. I'm using an AWS EC2 g5 instance with A10G GPUs (24 GB of memory each). The model I'm using is MPT-7B-8k. Here are the related configs:

--max-total-tokens 8192
--max-input-length 8064 # 8192 - 128
--max-batch-total-tokens 8192
--max-batch-prefill-tokens 8064
--num-shard 4

This results in OOM during warmup. Meanwhile, I also tried a shorter sequence length of 6144 and it worked. Once the server is ready, nvidia-smi shows memory consumption of just ~9 GB per GPU. Given that MPT-7B has 32 layers and a hidden size of 4096, sharding it across 4 GPUs with 24 GB each should be sufficient for a sequence length of 8192. Did I miss anything?
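
As a rough point of reference, a back-of-envelope KV-cache estimate for the numbers quoted above (assuming an fp16 cache and standard multi-head attention; activations, weights and vLLM block rounding are ignored, so this is only a lower bound, not a measurement):

# One 8,192-token sequence of MPT-7B: 32 layers, hidden size 4096, fp16.
layers, hidden, seq_len, bytes_fp16 = 32, 4096, 8192, 2
kv_bytes = 2 * layers * hidden * seq_len * bytes_fp16            # K and V
print(f"KV cache total:     {kv_bytes / 1024**3:.1f} GiB")       # ~4.0 GiB
print(f"per shard (4 GPUs): {kv_bytes / 4 / 1024**3:.1f} GiB")   # ~1.0 GiB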

comaniac avatar Aug 01 '23 21:08 comaniac

@comaniac, I am still using the 4k version with a single A10G and that works, but I believe there was a bugfix related to this which was merged after v0.9.3, so try v0.9.4 (which I use with the 4k model).

zoltan-fedor avatar Aug 01 '23 22:08 zoltan-fedor

I've actually tried v1.0.0, but no luck. My guess is that the vLLM cache calculation doesn't really take the sequence length into account and may allocate too many memory blocks to the KV cache, which results in OOM when memory is then allocated for activations during warmup.

comaniac avatar Aug 01 '23 22:08 comaniac

0.9.3 had issues because we were using AsyncMalloc, and it seems PyTorch doesn't do a great job of tracking those allocations, leading to all sorts of issues everywhere. We rolled it back for 0.9.4 (I'm not sure about the exact versions, but that sounds right).

@comaniac vLLM uses blocks, so it can definitely use more memory at peak than the theoretical maximum (a single token that creates a new block will create 128 slots for every request).
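
To illustrate the block rounding described above (taking the 128-slot block size quoted in the comment at face value; the helper below is purely illustrative):

import math

BLOCK_SIZE = 128  # slots per block, as quoted above

def allocated_slots(num_tokens: int) -> int:
    # Tokens live in fixed-size blocks, so the allocation is rounded up to a
    # whole number of blocks per request.
    return math.ceil(num_tokens / BLOCK_SIZE) * BLOCK_SIZE

print(allocated_slots(128))  # 128 -> exactly one block
print(allocated_slots(129))  # 256 -> one extra token pins a whole new block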

Narsil avatar Aug 03 '23 08:08 Narsil

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] avatar May 06 '24 01:05 github-actions[bot]