[Misc]: Throughput/Latency for guided_json with ~100% GPU cache utilization

Open jens-create opened this issue 11 months ago • 43 comments

Anything you want to discuss about vllm.

Hi,

I am running some benchmarks on vllm.entrypoints.openai.api_server, measuring latency and throughput with different numbers of concurrent requests.

Specs:

  • H100 80GB
  • qwen-1.5-14B-chat

I am sending 1000 requests with random prompts of token length 512. These are the results I get (see attached image):

Guided_json

  • ~100 running requests
  • ~70 generation tokens per second
  • ~1700 ms median token time

Non-guided_json

  • ~100 running requests
  • ~800 generation tokens per second
  • ~75 ms median token time (TPOT)

At 10 concurrent requests (GPU utilization << 100%):

  • Non-guided_json: ~20 ms median token time
  • guided_json: ~160 ms median token time

Currently the application I am building relies heavily on guided_json. However, to put it in an online setting I would like to ask: 1) are the numbers I am seeing sensible, and 2) what can be done to improve performance in the guided_json paradigm?

I am debating whether I should try to prompt my way to structured outputs and thus avoid constrained decoding.

[Screenshot 2024-03-22 at 10.10.14]

jens-create avatar Mar 22 '24 12:03 jens-create

Is the JSON schema complex at all, and is it the same each time? The 70 tok/s number is a bit lower than I expected. If it's the same schema, the slowdown can be due to several factors:

  • Currently the logits mask computation is performed on the critical path, but it can be moved earlier.
  • We currently don't batch the application of logits processors (illustrated in the sketch below).
  • General Python overhead.
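
For intuition on the second point, here is a schematic sketch (this illustrates the pattern, not vLLM's actual sampler code): logits processors run in a per-sequence Python loop at every decode step, so the guided-decoding mask work sits on the critical path and doesn't benefit from batching.

# Schematic sketch (not vLLM's actual sampler code): logits processors are
# applied per sequence in a Python loop at every decode step, so the mask
# work cannot be batched across the running requests.
from typing import Callable, Dict, List

import torch

LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]

def apply_logits_processors(
    logits: torch.Tensor,                        # [num_seqs, vocab_size]
    generated: Dict[int, List[int]],             # seq_id -> generated token ids so far
    processors: Dict[int, List[LogitsProcessor]],
) -> torch.Tensor:
    for row, seq_id in enumerate(generated):
        for proc in processors.get(seq_id, []):
            # One Python-level call per sequence, on the critical path.
            logits[row] = proc(generated[seq_id], logits[row])
    return logits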

I'm interested in fixing the performance here.

simon-mo avatar Mar 22 '24 16:03 simon-mo

Hi Simon,

The JSON schema is the same at all times, and it is as follows:

"guided_json": {"$defs": {"SearchQuery": {"description": "Search query for the retrieval task.", "properties": {"query_type": {"description": "The type of query most effective for handling the retrieval task.", "title": "Query Type", "type": "string"}, "query": {"description": "A random user's search query.", "title": "Query", "type": "string"}}, "required": ["query_type", "query"], "title": "SearchQuery", "type": "object"}}, "description": "A list of search queries anticipating a user looking for information from a given web page.", "properties": {"queries": {"description": "Brainstormed search queries for the given web page.", "items": {"$ref": "#/$defs/SearchQuery"}, "title": "Queries", "type": "array"}}, "required": ["queries"], "title": "Brainstorm", "type": "object"}

Thanks for looking into this 🫶

jens-create avatar Mar 25 '24 10:03 jens-create

@simon-mo any update on this? 😊

jens-create avatar Apr 23 '24 09:04 jens-create

Facing a similar issue here: I have a JSON schema with 14 fields, and the request gets stuck forever.

taoisu avatar Apr 24 '24 00:04 taoisu

My schema only has 2 fields and also has significantly higher latency than when running without guided_json. Would love to have this fixed, as model performance severely decreases without it.

lithafnium avatar May 07 '24 00:05 lithafnium

I would suggest trying out setting --guided-decoding-backend lm-format-enforcer (through server args) or "guided_decoding_backend": "lm-format-enforcer" as part of the request to see whether it helps. See the original PR here: https://github.com/vllm-project/vllm/pull/3868 (cc @noamgat)
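
For anyone wanting to try this quickly, a minimal sketch (the model name and schema here are placeholders):

# Server side: select the backend via a flag, e.g.
#   python -m vllm.entrypoints.openai.api_server \
#       --model Qwen/Qwen1.5-14B-Chat \
#       --guided-decoding-backend lm-format-enforcer
#
# Client side: or select it per request through extra_body.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="Qwen/Qwen1.5-14B-Chat",
    messages=[{"role": "user", "content": "Brainstorm search queries for this page: ..."}],
    extra_body={
        "guided_json": {"type": "object",
                        "properties": {"test": {"type": "string"}},
                        "required": ["test"]},
        "guided_decoding_backend": "lm-format-enforcer",
    },
)
print(resp.choices[0].message.content)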

simon-mo avatar May 07 '24 03:05 simon-mo

If testing lm-format-enforcer, I highly recommend adding the latest version of it to the image, as there have been performance improvements to the JsonSchemaParser. The next version of vLLM will include them, but until then, do pip install lm-format-enforcer==0.10.1 in the image before testing.

noamgat avatar May 07 '24 07:05 noamgat

What speeds are you getting @noamgat, vs. the Outlines backend?

jarrelscy avatar May 07 '24 07:05 jarrelscy

I didn't test on A100/H100s, but on my dev setup (RTX 3090, Mistral 7B), for simple schemas, I was getting less than a 2x reduction in tokens/s.

noamgat avatar May 07 '24 10:05 noamgat

+1, it doesn't seem GPU-related; I tested with A100 / V100 GPUs and both have a similar issue.

Using line_profiler, I found that this get_guided_decoding_logits_processor call takes 93% of the time.
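
For reference, a minimal way to collect this kind of line-level timing (the decorated function below is just a stand-in for the vLLM call named above, and the run command is the documented line_profiler workflow, not something specific to this issue):

# Requires `pip install line_profiler`.
from line_profiler import profile

@profile
def get_guided_decoding_logits_processor(request, tokenizer):
    ...  # builds the Outlines/LMFE logits processor for a request

# Run with `LINE_PROFILE=1 python your_script.py` (line_profiler >= 4.1)
# to print the per-line Hits / Time / % Time table.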

nullpointer0xffff avatar May 07 '24 11:05 nullpointer0xffff

If testing lm-format-enforcer, I highly recommend adding the latest version of it to the image, as there have been performance improvements to the JsonSchemaParser. The next version of vLLM will include them, but until then, do pip install lm-format-enforcer==0.10.1 in the image before testing.

This conflicts with vllm 0.4.1+cu118, which requires lm-format-enforcer==0.9.8, so installing it requires adding --no-deps.

Just tested; the speedup is not significant (25 tok/s -> 32 tok/s on V100). Probably the main bottleneck is still get_guided_decoding_logits_processor.

nullpointer0xffff avatar May 07 '24 12:05 nullpointer0xffff

@noamgat here's a profile from when I use lm-format-enforcer 0.10.1.

/lib/python3.10/site-packages/lmformatenforcer/integrations/transformers.py

Function: _build_regular_tokens_list at line 58

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    58                                           @profile
    59                                           def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str, bool]]:
    60         1  912794903.0    9e+08      9.5      token_0 = tokenizer.encode("0")[-1]
    61         1       8025.0   8025.0      0.0      regular_tokens = []
    62    128257   28050361.0    218.7      0.3      for token_idx in range(len(tokenizer)):
    63    128256   78294452.0    610.5      0.8          if token_idx in tokenizer.all_special_ids:
    64         2        450.0    225.0      0.0              continue
    65                                                   # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
    66    128254 5319568501.0  41476.8     55.3          decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
    67    128254 3162992335.0  24661.9     32.9          decoded_regular = tokenizer.decode([token_idx])
    68    128254   56427009.0    440.0      0.6          is_word_start_token = len(decoded_after_0) > len(decoded_regular)
    69    128254   61975079.0    483.2      0.6          regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
    70         1        240.0    240.0      0.0      return regular_tokens

The two decode calls in the loop seem to take most of the time. Happy to run further tests if needed.

nullpointer0xffff avatar May 07 '24 17:05 nullpointer0xffff

_build_regular_tokens_list only runs on the first request that uses LMFE. Does it happen every time? If so, maybe there is a problem with the lru caching not working.

noamgat avatar May 07 '24 19:05 noamgat

Just clarifying - if possible, start the tokens/s measurement and/or profiling from the second request onwards. While the warm-up time is also something that can be optimized, post-warmup performance matters much more for real-world use cases. This is true for all guided decoding backends.
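
In practice that just means something like the following (sketch; send_request is a placeholder for whatever client call is being benchmarked and is assumed to return its completion token count):

import time
from typing import Callable

def measure_post_warmup_tps(send_request: Callable[[], int], n_requests: int = 50) -> float:
    send_request()  # warm-up request: pays the one-time FSM / tokenizer-init cost
    start = time.perf_counter()
    total_tokens = sum(send_request() for _ in range(n_requests))
    elapsed = time.perf_counter() - start
    return total_tokens / elapsed  # post-warmup tokens/s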

noamgat avatar May 08 '24 05:05 noamgat

@nullpointer0xffff @jens-create I just confirmed that the caching of the LMFE tokenizer init (which is very slow) via @lru_cache is working, so _build_regular_tokens_list should never be called past the first request.

https://github.com/vllm-project/vllm/blob/63575bc2e197b85ce1c911421ff30c5459e35e9c/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py#L67
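
Roughly, the cached init looks like this (paraphrased sketch; see the linked file for the actual code):

from functools import lru_cache

from lmformatenforcer.integrations.vllm import build_vllm_token_enforcer_tokenizer_data

@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(tokenizer):
    # The expensive per-tokenizer scan (including _build_regular_tokens_list)
    # happens once here and is reused for every later request.
    return build_vllm_token_enforcer_tokenizer_data(tokenizer)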

Qubitium avatar May 08 '24 07:05 Qubitium

Maybe we can modify the __call__ method to separate the mask computation from the logits adjustment. This would allow the mask to be computed once and reused. Let me know if this makes sense @simon-mo
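
Something along these lines, as a sketch (hypothetical refactor, reusing the fsm.allowed_token_ids API that the current processor already calls):

import math
from typing import Dict

import torch

class SplitMaskLogitsProcessor:
    def __init__(self, fsm):
        self.fsm = fsm
        self._mask_cache: Dict[int, torch.Tensor] = {}  # fsm_state -> precomputed mask

    def compute_mask(self, fsm_state: int, vocab_size: int, device) -> torch.Tensor:
        """Build the -inf/0 mask for an FSM state once; later calls reuse it."""
        mask = self._mask_cache.get(fsm_state)
        if mask is None:
            allowed = self.fsm.allowed_token_ids(fsm_state)
            mask = torch.full((vocab_size, ), -math.inf, device=device)
            mask.index_fill_(0, torch.tensor(allowed, dtype=torch.int64, device=device), 0)
            self._mask_cache[fsm_state] = mask
        return mask

    def apply(self, fsm_state: int, scores: torch.Tensor) -> torch.Tensor:
        """Cheap step left on the critical path: just add the precomputed mask."""
        return scores.add_(self.compute_mask(fsm_state, scores.shape[-1], scores.device))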

SaloniGandhi avatar May 15 '24 01:05 SaloniGandhi

Just sharing my experience with this issue - it seems to align with the OP's experience.

Summary: CPU-constrained guidance means that batching can't scale correctly.

  • vLLM: 0.4.2
  • Outlines: 0.0.34
  • lm_format_enforcer: 0.10.2
  • Model: Llama 3 8B Instruct

Hardware:

  • Single A100 (80G)
  • AMD EPYC 7V13 (24 cores)

Single request:

  • Outlines: ~70 tps - CPU 100%
  • lm_format_enforcer: ~45 tps - CPU 100%
  • No guidance: ~140 tps

Batched requests:

  • Outlines: ~70 tps - CPU 100%
  • lm_format_enforcer: ~45 tps - CPU 100%
  • No guidance: ~1000 tps

Guided regex and json are both affected.

Example guidance:

regex
~~~response\n# Content\\n([.\\W\\w]+)\\n{2}~{3}
json
{"type":"object","properties":{"test":{"type":"string"}},"required":["test"]}

lynkz-matt-psaltis avatar May 22 '24 13:05 lynkz-matt-psaltis

Here are the line timings for model_executor/guided_decoding/outlines_logits_processors.py:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    41                                               @line_profiler.profile
    42                                               def __call__(self, input_ids: List[int],
    43                                                            scores: torch.Tensor) -> torch.Tensor:
    44                                                   """Use the FSM to bias the logits before sampling the next token."""
    45      2686      22898.6      8.5      0.1          seq_id = hash(tuple(input_ids))
    46
    47      2686       1994.5      0.7      0.0          if len(input_ids) == 0:
    48         3         13.6      4.5      0.0              self.init_state()
    49                                                   else:
    50      2683        953.2      0.4      0.0              last_token = input_ids[-1]
    51      2683      12007.7      4.5      0.0              last_seq_id = hash(tuple(input_ids[:-1]))
    52      5366      15540.9      2.9      0.0              self.fsm_state[seq_id] = self.fsm.next_state(
    53      2683       2226.4      0.8      0.0                  self.fsm_state[last_seq_id], last_token)
    54
    55      2686    2022417.3    752.9      5.2          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
    56
    57      5372      83383.0     15.5      0.2          mask = torch.full((scores.shape[-1], ),
    58      2686       3307.1      1.2      0.0                            -math.inf,
    59      2686       1901.2      0.7      0.0                            device=scores.device)
    60      2686   36566141.1  13613.6     94.3          mask[allowed_tokens] = 0
    61      2686      36379.1     13.5      0.1          scores.add_(mask)
    62      2686        794.1      0.3      0.0          return scores

lynkz-matt-psaltis avatar May 23 '24 03:05 lynkz-matt-psaltis

Based on that timing breakdown, can you try replacing mask[allowed_tokens] = 0 with torch index_fill, e.g. mask.index_fill_(0, allowed_tokens, 0)? This might be faster than manually indexing the mask tensor.
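
Spelled out on dummy data, the suggested change looks like this (index_fill_ needs a LongTensor index on the same device as the mask):

import math

import torch

scores = torch.randn(32000)                               # stand-in for one sequence's logits
allowed_tokens = torch.tensor([1, 5, 42], dtype=torch.int64)

mask = torch.full((scores.shape[-1], ), -math.inf)
mask.index_fill_(0, allowed_tokens, 0)                    # instead of mask[allowed_tokens] = 0
scores.add_(mask)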

felixzhu555 avatar May 23 '24 05:05 felixzhu555

I've been doing some further perf analysis and breaking things out a bit to try to understand the bottleneck. It doesn't seem to be related to the indexer, but rather to moving the allowed_tokens array around.

CPU first, then move to GPU:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               @line_profiler.profile
    44                                               def __call__(self, input_ids: List[int],
    45                                                            scores: torch.Tensor) -> torch.Tensor:
    46                                                   """Use the FSM to bias the logits before sampling the next token."""
    47      2529      18368.3      7.3      0.1          seq_id = hash(tuple(input_ids))
    48
    49      2529       2418.0      1.0      0.0          if len(input_ids) == 0:
    50         3         22.6      7.5      0.0              self.init_state()
    51                                                   else:
    52      2526        886.1      0.4      0.0              last_token = input_ids[-1]
    53      2526      11457.4      4.5      0.1              last_seq_id = hash(tuple(input_ids[:-1]))
    54      5052      14539.8      2.9      0.1              self.fsm_state[seq_id] = self.fsm.next_state(
    55      2526       1931.4      0.8      0.0                  self.fsm_state[last_seq_id], last_token)
    56
    57      2529    1903376.3    752.6     10.5          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
    58      2529   15524244.9   6138.5     85.5          allowed_tokens_tensor = torch.tensor(allowed_tokens, dtype=torch.int32, device='cpu')
    59
    60      2529       3262.6      1.3      0.0          if self.mask is None or self.allowed_tokens_tensor is None:
    61      2529      82721.9     32.7      0.5              self.mask = torch.full_like(scores, -math.inf)
    62                                                   else:
    63                                                       self.mask.fill_(-math.inf)
    64                                                   
    65      2529       4009.6      1.6      0.0          if (allowed_tokens_tensor.device != scores.device):
    66      2529     489064.9    193.4      2.7              allowed_tokens_tensor = allowed_tokens_tensor.to(scores.device)
    67                                                       
    68      2529      39004.4     15.4      0.2          allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)
    69                                                   
    70      2529      35650.6     14.1      0.2          self.mask.index_fill_(0, allowed_tokens_tensor, 0)
    71      2529      23729.9      9.4      0.1          scores.add_(self.mask)
    72
    73      2529        630.8      0.2      0.0          return scores

Straight to GPU:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               @line_profiler.profile
    44                                               def __call__(self, input_ids: List[int],
    45                                                            scores: torch.Tensor) -> torch.Tensor:
    46                                                   """Use the FSM to bias the logits before sampling the next token."""
    47      2252      14057.0      6.2      0.1          seq_id = hash(tuple(input_ids))
    48
    49      2252       1943.8      0.9      0.0          if len(input_ids) == 0:
    50         3         12.1      4.0      0.0              self.init_state()
    51                                                   else:
    52      2249        696.3      0.3      0.0              last_token = input_ids[-1]
    53      2249       9021.9      4.0      0.1              last_seq_id = hash(tuple(input_ids[:-1]))
    54      4498      14201.8      3.2      0.1              self.fsm_state[seq_id] = self.fsm.next_state(
    55      2249       1836.6      0.8      0.0                  self.fsm_state[last_seq_id], last_token)
    56
    57      2252    1692571.2    751.6     10.5          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
    58      2252   14277941.7   6340.1     88.3          allowed_tokens_tensor = torch.tensor(allowed_tokens, dtype=torch.int32, device=scores.device)
    59
    60      2252       2582.4      1.1      0.0          if self.mask is None or self.allowed_tokens_tensor is None:
    61      2252      55524.5     24.7      0.3              self.mask = torch.full_like(scores, -math.inf)
    62                                                   else:
    63                                                       self.mask.fill_(-math.inf)
    64                                                   
    65      2252       3560.3      1.6      0.0          if (allowed_tokens_tensor.device != scores.device):
    66                                                       allowed_tokens_tensor = allowed_tokens_tensor.to(scores.device)
    67                                                       
    68      2252      34986.8     15.5      0.2          allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)
    69                                                   
    70      2252      32876.8     14.6      0.2          self.mask.index_fill_(0, allowed_tokens_tensor, 0)
    71      2252      22152.8      9.8      0.1          scores.add_(self.mask)
    72
    73      2252        633.6      0.3      0.0          return scores

lynkz-matt-psaltis avatar May 23 '24 09:05 lynkz-matt-psaltis

58     12693    9401753.9    740.7     17.2          allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
59     12693   42707835.5   3364.7     78.1          np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32)
60     12693      73736.7      5.8      0.1          allowed_tokens_tensor = torch.from_numpy(np_allowed_tokens)

Halved the cost by building a NumPy array first; also tried torch.as_tensor but saw no significant change.
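
In isolation, the change being measured is roughly the following (dummy data; in the real processor the list comes from fsm.allowed_token_ids and the target device is scores.device):

import numpy as np
import torch

allowed_tokens = list(range(0, 120000, 3))                # Python list, as returned by Outlines
device = "cuda" if torch.cuda.is_available() else "cpu"

# Slower: build the tensor straight from the Python list
t_slow = torch.tensor(allowed_tokens, dtype=torch.int32, device=device)

# Faster: list -> NumPy array (one bulk conversion) -> zero-copy tensor -> device
t_fast = torch.from_numpy(np.array(allowed_tokens, dtype=np.int32)).to(device)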

lynkz-matt-psaltis avatar May 23 '24 10:05 lynkz-matt-psaltis

Beyond this, I'm not sure I see a way forward without changes to Outlines and lm-format-enforcer to provide the information in a more efficient structure than a Python list. Does anyone see any memoization opportunities here to at least reduce the iteration counts?

lynkz-matt-psaltis avatar May 23 '24 10:05 lynkz-matt-psaltis

One thing I think we could do to make it faster is to use the fact that allowed_tokens is either almost all the tokens, or none of the tokens. Currently the mask is created at -math.inf, but we could also create the mask at 0 if the length of allowed_tokens is < scores.shape[0]/2 and then fill_ with -math.inf instead?

jarrelscy avatar May 23 '24 10:05 jarrelscy

I went down that same line of thinking - I don't think the timings above support it, however. It's getting the Python List into a Tensor that seems to be 80%+ of the cost per iteration. So short of data structure changes upstream, my current thinking is that we're left with iteration optimisations - can we avoid going back to fsm.allowed_token_ids in certain situations? Not sure on that yet - still learning how this all fits together.

lynkz-matt-psaltis avatar May 23 '24 10:05 lynkz-matt-psaltis

Are the PRs for this issue currently stalled due to competing priorities?

JGSweets avatar Jun 26 '24 20:06 JGSweets

I believe my PR mitigates the issue; I would appreciate some testing to verify: https://github.com/outlines-dev/outlines/pull/1013

It decreases the worst-case Outlines structured generation logits processor overhead from 50 ms to 1 ms.

Please let me know if this doesn't resolve this thread's issue and another approach is needed.

Edit:

It's available in Outlines main

pip install git+https://github.com/outlines-dev/outlines

lapp0 avatar Jul 02 '24 14:07 lapp0

I did not test the solution yet, but here is the same issue posted to Outlines.

https://github.com/outlines-dev/outlines/issues/1011#event-13374395930

I suspected it might be a threading issue in vLLM's logit decoding, but it sounds like this thread would have picked up on that if that were the case.

robcaulk avatar Jul 03 '24 07:07 robcaulk

Just anecdotally, I am still seeing a 20x slowdown doing batch generation with JSONLogitsProcessor enabled, even with @lapp0's PR. Let me know if I can provide more information.

keeth avatar Jul 04 '24 03:07 keeth

Thanks for reporting back

It seems we've tackled one component of the slowness, as reported by @lynkz-matt-psaltis: "It's getting the Python List into a Tensor that seems to be 80%+ of the cost per iteration".

However, as suggested by @robcaulk, there may be another source of slowness in vLLM's Outlines integration, involving the ThreadPoolExecutor: https://github.com/vllm-project/vllm/blob/80ca1e6a3a28a0373dc00c5b4fe956c16de952fa/vllm/model_executor/guided_decoding/outlines_decoding.py#L71

I also have an alternative hypothesis: iff the issue only occurs when using vLLM + ray, perhaps the state of these logits processors is expensive to communicate between ray workers.

Could you please

    1. rm -rf ~/.cache/outlines just to be sure the cached RegexFSM isn't from a previous version
    2. Share the version details of relevant packages with pip freeze
    3. Provide a reproduction script and hardware details (especially the GPU type and count) - a minimal sketch follows below
    4. If you've done any profiling, share that as well
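
For point 3, a minimal concurrent-repro sketch against the OpenAI-compatible server could look like this (model name, schema, prompt, and concurrency are placeholders):

import asyncio

from openai import AsyncOpenAI

SCHEMA = {"type": "object",
          "properties": {"test": {"type": "string"}},
          "required": ["test"]}

client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def one_request(i: int) -> int:
    resp = await client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        messages=[{"role": "user", "content": f"Return a JSON object for item {i}."}],
        extra_body={"guided_json": SCHEMA},
        max_tokens=128,
    )
    return resp.usage.completion_tokens

async def main(n: int = 100) -> None:
    counts = await asyncio.gather(*(one_request(i) for i in range(n)))
    print(f"{sum(counts)} completion tokens across {n} concurrent requests")

if __name__ == "__main__":
    asyncio.run(main())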

lapp0 avatar Jul 04 '24 18:07 lapp0