
Out of Memory Errors When Running text-generation-benchmark Despite Compliant Batch Token Limit

Open martinigoyanes opened this issue 9 months ago • 8 comments

Environment

Runtime environment:

  • Target: x86_64-unknown-linux-gnu
  • Cargo version: 1.75.0
  • Commit sha: c38a7d7ddd9c612e368adec1ef94583be602fc7e
  • Docker label: sha-6c4496a

Kubernetes Cluster deployment

1x NVIDIA A100 GPU with 80 GB of VRAM

12 CPUs with 32 GB of RAM

TGI version: 2.0.0

What I am doing

I am running text-generation-benchmark to find the sweet spot between throughput and latency for my hardware. I am trying to maximize the batch tokens by using the MAX_BATCH_TOTAL_TOKENS value inferred by text-generation-launcher; however, I get out-of-memory errors.

When running export LOG_LEVEL=INFO; text-generation-launcher --hostname 0.0.0.0 --port 8080, I see MAX_BATCH_TOTAL_TOKENS inferred to be 425472:

2024-04-30T10:25:51.994120Z  INFO text_generation_launcher: Args { model_id: "/model_data/mistral7b-free", revision: None, validation_workers: 2, sharded: None, num_shard: None, quantize: None, speculate: None, dtype: None, trust_remote_code: false, max_concurrent_requests: 128, max_best_of: 2, max_stop_sequences: 4, max_top_n_tokens: 5, max_input_tokens: None, max_input_length: Some(8000), max_total_tokens: Some(8512), waiting_served_ratio: 0.3, max_batch_prefill_tokens: Some(32768), max_batch_total_tokens: Some(4294967295), max_waiting_tokens: 0, max_batch_size: None, cuda_graphs: None, hostname: "0.0.0.0", port: 8080, shard_uds_path: "/tmp/text-generation-server", master_addr: "localhost", master_port: 29500, huggingface_hub_cache: Some("/data"), weights_cache_override: None, disable_custom_kernels: false, cuda_memory_fraction: 1.0, rope_scaling: None, rope_factor: None, json_output: false, otlp_endpoint: None, cors_allow_origin: [], watermark_gamma: None, watermark_delta: None, ngrok: false, ngrok_authtoken: None, ngrok_edge: None, tokenizer_config_path: None, disable_grammar_support: false, env: false }
2024-04-30T10:25:51.994178Z  INFO text_generation_launcher: Model supports up to 32768 but tgi will now set its default to 4096 instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens=32818 --max-total-tokens=32768 --max-input-tokens=32767`.
2024-04-30T10:25:51.994184Z  INFO text_generation_launcher: Using default cuda graphs [1, 2, 4, 8, 16, 32]
2024-04-30T10:25:51.994271Z  INFO download: text_generation_launcher: Starting download process.
2024-04-30T10:25:54.856372Z  INFO text_generation_launcher: Files are already present on the host. Skipping download.

2024-04-30T10:25:55.330625Z  INFO download: text_generation_launcher: Successfully downloaded weights.
2024-04-30T10:25:55.330812Z  INFO shard-manager: text_generation_launcher: Starting shard rank=0
2024-04-30T10:26:05.338492Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
2024-04-30T10:26:15.433799Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
2024-04-30T10:26:17.323011Z  INFO text_generation_launcher: Server started at unix:///tmp/text-generation-server-0

2024-04-30T10:26:17.335204Z  INFO shard-manager: text_generation_launcher: Shard ready in 22.003836391s rank=0
2024-04-30T10:26:17.431088Z  INFO text_generation_launcher: Starting Webserver
2024-04-30T10:26:17.498225Z  INFO text_generation_router: router/src/main.rs:250: Using config Some(Mistral)
2024-04-30T10:26:17.498245Z  INFO text_generation_router: router/src/main.rs:257: Using local tokenizer config
2024-04-30T10:26:17.498263Z  WARN text_generation_router: router/src/main.rs:292: no pipeline tag found for model /model_data/mistral7b-free
2024-04-30T10:26:17.500561Z  INFO text_generation_router: router/src/main.rs:311: Warming up model
2024-04-30T10:26:20.987760Z  INFO text_generation_launcher: Cuda Graphs are enabled for sizes [1, 2, 4, 8, 16, 32]

2024-04-30T10:26:21.845520Z  WARN text_generation_router: router/src/main.rs:333: `--max-batch-total-tokens` is deprecated for Flash Attention models.
2024-04-30T10:26:21.845531Z  WARN text_generation_router: router/src/main.rs:337: Inferred max batch total tokens: 425472
2024-04-30T10:26:21.845534Z  INFO text_generation_router: router/src/main.rs:348: Setting max batch total tokens to 425472
2024-04-30T10:26:21.845536Z  INFO text_generation_router: router/src/main.rs:349: Connected

Therefore, even though text-generation-benchmark bypasses the router completely, I should be able to process 425472 tokens at the same time without running into out-of-memory errors, right?

So I want to see the latency for this load: 53 requests | 4000 sequence length | 4000 decode length -> 53 requests * (4000 in + 4000 out) = 424000 tokens processed concurrently. This is indeed lower than the inferred upper bound (424000 < 425472).

This is the command I am running: text-generation-benchmark --tokenizer-name /model_data/mistral7b-free/ -b 53 --sequence-length 4000 --decode-length 4000
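For reference, a quick sanity check of that arithmetic in plain Python (the 8512 and 425472 values are taken from the launcher logs above):

batch_size = 53
sequence_length = 4000   # prompt tokens per request
decode_length = 4000     # generated tokens per request

per_request_tokens = sequence_length + decode_length
batch_total_tokens = batch_size * per_request_tokens

print(per_request_tokens)             # 8000, below max_total_tokens = 8512
print(batch_total_tokens)             # 424000
print(batch_total_tokens <= 425472)   # True: below the inferred MAX_BATCH_TOTAL_TOKENS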

What is the unexpected behavior

However, I get out-of-memory errors. Here are the logs:

text-generation-benchmark

2024-04-30T10:20:12.210371Z  INFO text_generation_benchmark: benchmark/src/main.rs:138: Loading tokenizer
2024-04-30T10:20:12.210408Z  INFO text_generation_benchmark: benchmark/src/main.rs:144: Found local tokenizer
[text-generation-benchmark TUI output: Model: /model_data/mistral7b-free/ | Sequence Length: 4000 | Decode Length: 4000 | Batch: 53. Total and batch progress stay at 0 / 1; all latency and throughput panels report NaN, since no batch ever completes.]
2024-04-30T10:35:01.733254Z ERROR prefill{id=0 size=53}:prefill{id=0 size=53}: text_generation_client: router/client/src/lib.rs:33: Server error: CANCELLED

text-generation-launcher

 File "/opt/conda/lib/python3.10/site-packages/text_generation_server/utils/layers.py", line 159, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.32 GiB. GPU 0 has a total capacty of 79.14 GiB of which 3.98 GiB is free. Process 2218444 has 75.15 GiB memory in use. Of the allocated memory 73.66 GiB is allocated by PyTorch, and 963.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

nvidia-smi after OOM error

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:17:00.0 Off |                    0 |
| N/A   46C    P0             87W /  300W |   76959MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

Extra notes

  • I have also tried lower batch sizes (50, 40, 32), and they also lead to OOM errors. I could get batch_size=2 to work, though.
  • I have also tried not bypassing the router and sending 60 concurrent requests with sequence + decode lengths = 8000, and this DOES WORK. So I do NOT understand why it fails when bypassing the router, given that the router only really prevents you from exceeding MAX_BATCH_TOTAL_TOKENS, which I am explicitly not exceeding when using text-generation-benchmark. What am I missing?

Grafana dashboard of tgi_batch_current_max_tokens when going through the router (423k tokens in a batch, very close to the inferred MAX_BATCH_TOTAL_TOKENS): Screenshot 2024-04-30 at 12 42 32

martinigoyanes avatar Apr 30 '24 10:04 martinigoyanes

Same problem here as well.

fhkingma avatar Apr 30 '24 10:04 fhkingma

@Narsil do you have any thoughts on why this could be happening?

martinigoyanes avatar May 03 '24 09:05 martinigoyanes

@OlivierDehaene maybe you have some insights on this matter?

Thank you for your time!

martinigoyanes avatar May 06 '24 10:05 martinigoyanes

Hey @Venkat2811, maybe you could enlighten me in this area? I would really appreciate it!

martinigoyanes avatar May 12 '24 13:05 martinigoyanes

Hey @martinigoyanes ,

Just taking a stab at the issue here. Without knowing the actual model config, and going by this article: the amount of GPU memory consumed scales with the base model size plus the length of the token sequence.

Base model size: 14 GB (7B params * 2 bytes in fp16). 80 GB - 14 GB = 66 GB available for inference. During prefill: 53 requests * 4k seq_len * ~1 MB/token = 53 * 4 * 1 GB = 212 GB (as per the article, 1 token requires ~1 MB).

That estimate far exceeds the ~66 GB available, resulting in the error during prefill:

2024-04-30T10:35:01.733254Z ERROR prefill{id=0 size=53}:prefill{id=0 size=53}: text_generation_client: router/client/src/lib.rs:33: Server error: CANCELLED

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.32 GiB. GPU 0 has a total capacty of 79.14 GiB of which 3.98 GiB is free. Process 2218444 has 75.15 GiB memory in use. Of the allocated memory 73.66 GiB is allocated by PyTorch, and 963.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
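For concreteness, a minimal Python sketch of that back-of-envelope estimate (assuming fp16 weights at 2 bytes per parameter and the article's rough ~1 MB of KV cache per token; these are rule-of-thumb numbers, not measurements):

GB = 1e9   # decimal units, matching the rough figures above
MB = 1e6

params = 7e9                   # ~7B parameters for Mistral-7B
weights_gb = params * 2 / GB   # fp16, 2 bytes/param -> 14 GB
free_gb = 80 - weights_gb      # left for KV cache and activations
print(free_gb)                 # 66.0

def kv_cache_gb(batch_size, tokens_per_request, mb_per_token=1.0):
    # rule of thumb: ~1 MB of KV cache per token held in the batch
    return batch_size * tokens_per_request * mb_per_token * MB / GB

print(kv_cache_gb(53, 4000))        # 212.0 -> prefill alone needs far more than 66 GB
print(kv_cache_gb(2, 4000 + 4000))  # 16.0  -> fits, consistent with batch_size=2 working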

This worked for you:

  • I have also tried lower batch sizes (50, 40, 32), and they also lead to OOM errors. I could get batch_size=2 to work, though.

With 66 GB available for inference, the memory for prefill + decode at batch size 2 is: 2 * (4k seq_len + 4k output_len) * ~1 MB/token = 2 * 8 * 1 GB = 16 GB.

Did you try batch sizes of 3 and 4? I think those should be possible.

  • I have also tried not bypassing the router and sending 60 concurrent requests with sequence + decode lengths = 8000, and this DOES WORK.

From the graph, it took ~10 minutes. If that was the time to complete inference for all 60 requests, I think the above math holds, i.e., concurrent processing of 2-4 requests per batch.

I haven't looked into tgi_batch_current_max_tokens & MAX_BATCH_TOTAL_TOKENS yet.

Venkat2811 avatar May 13 '24 09:05 Venkat2811

Thanks for the response. Yeah, the math makes sense, I completely agree. My problem is that I also did the math, but TGI tells me MAX_BATCH_TOTAL_TOKENS = 425472, and if you actually try to submit that many tokens you get OOM errors. If 1 token ~ 1 MB, that implies I would need 425k MB, which is roughly 425 GB of free VRAM.

Meanwhile, in theory, MAX_BATCH_TOTAL_TOKENS is the total number of tokens that can fit in a batch sent to the LLM. See https://github.com/huggingface/text-generation-inference/blob/d348d2b28feeaab7a8f6bd44cc8924b6b4ae7868/router/src/queue.rs#L185

Also for reference:

 --max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
          **IMPORTANT** This is one critical control to allow maximum usage of the available hardware.
          
          This represents the total amount of potential tokens within a batch. When using padding (not recommended) this would be equivalent of `batch_size` * `max_total_tokens`.
          
          However in the non-padded (flash attention) version this can be much finer.
          
          For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` or a single query of `1000` tokens.
          
          Overall this number should be the largest possible amount that fits the remaining memory (after the model is loaded). Since the actual memory overhead depends on other parameters like if you're using quantization, flash attention or the model implementation, text-generation-inference cannot infer this number automatically.
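To make my expectation concrete, here is a simplified, illustrative Python sketch of the admission check I understood the router to be doing, going by the queue.rs line and the help text above. This is not TGI's actual implementation, just the budget logic in spirit:

from dataclasses import dataclass

@dataclass
class Request:
    input_tokens: int      # prompt length
    max_new_tokens: int    # maximum tokens the request may generate

def fits_in_batch(requests, max_batch_total_tokens):
    # Each request reserves its prompt plus its maximum possible output,
    # and the batch-wide reservation must stay within the token budget.
    reserved = sum(r.input_tokens + r.max_new_tokens for r in requests)
    return reserved <= max_batch_total_tokens

batch = [Request(4000, 4000)] * 53
print(fits_in_batch(batch, 425472))  # True (424000 <= 425472), so I expected this batch to fit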

martinigoyanes avatar May 13 '24 14:05 martinigoyanes

maybe related https://github.com/huggingface/text-generation-inference/issues/1286

fxmarty avatar May 16 '24 08:05 fxmarty

+1

fxmarty avatar May 16 '24 17:05 fxmarty

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] avatar Jun 17 '24 01:06 github-actions[bot]

In the launcher, one needs to set the --max-batch-prefill-tokens option to the maximum total prefill tokens (batch_size * prompt_length) that will be benchmarked with the benchmark tool. Essentially, the benchmark tool bypasses the scheduler, so otherwise not enough memory is reserved for intermediary buffers.
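For example (hypothetical flag values, derived from the benchmark above: 53 requests * 4000 prompt tokens = 212000 prefill tokens), the launcher and benchmark invocations would look something like:

text-generation-launcher --hostname 0.0.0.0 --port 8080 --max-batch-prefill-tokens 212000
text-generation-benchmark --tokenizer-name /model_data/mistral7b-free/ -b 53 --sequence-length 4000 --decode-length 4000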

fxmarty avatar Jun 17 '24 10:06 fxmarty

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] avatar Jul 19 '24 01:07 github-actions[bot]