[Performance]: Transformers 4.45.1 slows down `outlines` guided decoding

Open joerunde opened this issue 1 year ago • 1 comments

Report of performance regression

I noticed that guided decoding was a bit slower on newer builds of vllm, but couldn't track down a commit that caused a performance regression. Instead it looks like upgrading transformers from 4.44.2 to 4.45.1 causes the issue.

I ran a small Artillery load test with requests using guided decoding, running the code from commit 4f1ba0844. This is the last commit before mllama support was added, so it's the last point where vllm works with both transformers 4.44.2 and 4.45.1. vllm was run on 1xA100 GPU with the model mistralai/Mistral-7B-Instruct-v0.2.

The results with 4.44.2 installed:

http.codes.200: ................................................................ 240
http.downloaded_bytes: ......................................................... 91928
http.request_rate: ............................................................. 3/sec
http.requests: ................................................................. 240
http.response_time:
  min: ......................................................................... 105
  max: ......................................................................... 16348
  mean: ........................................................................ 6655.3
  median: ...................................................................... 3905.8
  p95: ......................................................................... 15526
  p99: ......................................................................... 16159.7
http.responses: ................................................................ 240
vusers.completed: .............................................................. 60
vusers.created: ................................................................ 60
vusers.created_by_name.Test completions: ....................................... 60
vusers.failed: ................................................................. 0
vusers.session_length:
  min: ......................................................................... 15318.1
  max: ......................................................................... 38021.7
  mean: ........................................................................ 26628.2
  median: ...................................................................... 27730.6
  p95: ......................................................................... 33199.7
  p99: ......................................................................... 35964.9

and with 4.45.1 installed:

http.codes.200: ................................................................ 240
http.downloaded_bytes: ......................................................... 92209
http.request_rate: ............................................................. 3/sec
http.requests: ................................................................. 240
http.response_time:
  min: ......................................................................... 100
  max: ......................................................................... 27083
  mean: ........................................................................ 10279.2
  median: ...................................................................... 5065.6
  p95: ......................................................................... 26115.6
  p99: ......................................................................... 27181.5
http.responses: ................................................................ 240
vusers.completed: .............................................................. 60
vusers.created: ................................................................ 60
vusers.created_by_name.Test completions: ....................................... 60
vusers.failed: ................................................................. 0
vusers.session_length:
  min: ......................................................................... 19387.6
  max: ......................................................................... 55055.6
  mean: ........................................................................ 41123.7
  median: ...................................................................... 43928
  p95: ......................................................................... 51550.2
  p99: ......................................................................... 53654.1

The slowdown looks pretty significant to me 🐌🐌🐌
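To put rough numbers on it, here is a quick sketch that computes the relative change for a few of the metrics above. The values are copied from the two runs (Artillery reports these times in milliseconds, as far as I know), and the helper name `pct_change` is only for illustration.

```python
# Response-time and session-length figures copied from the two runs above.
baseline = {  # transformers 4.44.2
    "response_time.mean": 6655.3,
    "response_time.median": 3905.8,
    "response_time.p95": 15526.0,
    "session_length.mean": 26628.2,
}
upgraded = {  # transformers 4.45.1
    "response_time.mean": 10279.2,
    "response_time.median": 5065.6,
    "response_time.p95": 26115.6,
    "session_length.mean": 41123.7,
}


def pct_change(old: float, new: float) -> float:
    """Relative change from old to new, in percent."""
    return (new - old) / old * 100.0


changes = {name: pct_change(baseline[name], upgraded[name]) for name in baseline}

for name, change in changes.items():
    print(f"{name}: {change:+.1f}%")
```

That works out to roughly 54% on mean response time, 30% on the median, and 68% at p95, with mean session length also up by about 54%.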

I wasn't able to get the vllm profiling to work to dig in at all; unfortunately it kept crashing with encoding errors whenever I ran any requests with guided decoding. So I don't know if this is a problem with vllm, with outlines, or with transformers. But given that outlines hasn't been updated in quite a while and sglang went and forked it, I'm not sure if this is worth investigating as is or if it'll be overcome by events.

Anybody have ideas about what could be going wrong?

The scripts I ran:

artillery.yaml

config:
  timeout: 100
  target: http://rundemc-dev-service:8000
  phases:
    - duration: 180
      arrivalRate: 1
      name: Load test

  payload:
    # path is relative to the location of the test script
    path: 'payloads.csv'
    fields:
      - prompt
    name: unused

  variables:
    model_id:
      - "mistralai/Mistral-7B-Instruct-v0.2"
    backend:
      - "lm-format-enforcer"


scenarios:
  - name: Test completions
    flow:
      - post:
          url: "/v1/completions"
          json:
            model: "{{ model_id }}"
            prompt: "{{ prompt }}"
            max_tokens: 40
      - post:
          url: "/v1/completions"
          json:
            model: "{{ model_id }}"
            prompt: "{{ prompt }}"
            max_tokens: 40
            guided_decoding_backend: "{{ backend }}"
            guided_choice:
              - "foo"
              - "bar"
              - "baz"
              - "buzz"
      - post:
          url: "/v1/completions"
          json:
            model: "{{ model_id }}"
            prompt: "{{ prompt }}"
            max_tokens: 40
            guided_decoding_backend: "{{ backend }}"
            response_format:
              type: "json_object"
      - post:
          url: "/v1/completions"
          json:
            model: "{{ model_id }}"
            prompt: "{{ prompt }}"
            max_tokens: 40
            guided_decoding_backend: "{{ backend }}"
            guided_json:
              type: "object"
              properties:
                name:
                  type: string
                age:
                  type: integer

payloads.csv

"hello world this is jesus"
"Lorem ipsum dolor"
"Write a function that sums two numbers together"

(obviously very scientific 😉 )

joerunde avatar Oct 02 '24 22:10 joerunde

cc @RonanKMcGovern since you've been running into this issue as well.

DarkLight1337 avatar Oct 03 '24 02:10 DarkLight1337

Apparently the latest release of outlines has lots of performance enhancements: https://github.com/dottxt-ai/outlines/releases/tag/0.1.0

mgoin avatar Oct 08 '24 02:10 mgoin

@joerunde wanted to follow up on this. I am noticing that latencies are super high while using guided decoding with vllm - https://github.com/dottxt-ai/outlines/issues/1241.

In your opinion, what is currently the best way to close the latency gap? Have you tried the latest version of outlines to see if that helps?

DhruvaBansal00 avatar Nov 04 '24 22:11 DhruvaBansal00

So this is an issue I have been digging into. The biggest problem appears to be the call

self._guide.get_next_instruction

in BaseLogitsProcessor

In this specific capture it's 27%, but it can be as high as 99% depending on the input.

So there likely isn't a good fix to make this code faster. However, it only runs in the engine process, so a single thread/process has to handle ALL of these calls, and everything stalls while the engine process is pegged at 100%. It would be nice to see if we can find a way to do this call asynchronously across multiple processes or something.
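As a rough sketch of what moving that call out of the engine process could look like, the per-sequence lookups can be submitted to an executor and collected afterwards. This is not vllm or outlines code: `allowed_tokens_for_state` is a hypothetical stand-in for the expensive guide lookup, and `lookup_all` is a made-up helper name.

```python
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Dict, List


def allowed_tokens_for_state(state: int) -> List[int]:
    """Hypothetical stand-in for the expensive, CPU-bound guide lookup."""
    # Toy rule for the demo: state s allows token ids s, s + 1 and s + 2.
    return [state, state + 1, state + 2]


def lookup_all(executor: Executor, states: List[int]) -> Dict[int, List[int]]:
    """Submit one lookup per state, then collect the results by state."""
    futures = {
        state: executor.submit(allowed_tokens_for_state, state)
        for state in states
    }
    return {state: future.result() for state, future in futures.items()}


if __name__ == "__main__":
    # A thread pool keeps the demo simple. For pure-Python CPU-bound work,
    # a ProcessPoolExecutor would be needed to get around the GIL, and the
    # function and its arguments would then have to be picklable.
    with ThreadPoolExecutor(max_workers=4) as pool:
        print(lookup_all(pool, [0, 10, 20]))
```

The pickling requirement is probably the hard part in practice: anything the workers need (the guide included) has to cross the process boundary, and the Lark comment in the profiled code suggests that does not always work cleanly.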

Also we are kinda eating it on that np.array call (63.2% of the time in this capture). It should use a cached array here or something if we can avoid the copy. It might be an option to look into making outlines NumPy-aware.

Line # Hits Time Per Hit % Time Line Contents

47                                                   def inner_call(input_ids, scores):                                              
48                                                       """Use the FSM to bias the logits before sampling the next token."""        
49         1       2200.0   2200.0      0.0              seq_id = hash(tuple(input_ids))                                             
50                                                                                                                                   
51         1        710.0    710.0      0.0              if len(input_ids) > 0:                                                      
52         1        340.0    340.0      0.0                  last_token = input_ids[-1]                                              
53         1       3230.0   3230.0      0.0                  last_seq_id = hash(tuple(input_ids[:-1]))                               
54         2     143240.0  71620.0      1.7                  self._fsm_state[seq_id] = self._guide.get_next_state(                 
55         1       1550.0   1550.0      0.0                      state=self._fsm_state[last_seq_id], token_id=last_token)            
56                                                       else:                                                                       
57                                                           # Note: this is a hack.                                                 
58                                                           # Lark pickling does not work properly (silent failure),                
59                                                           # which breaks the RPC (which uses python pickleing).                   
60                                                           # We need to find a better solution.                                    
61                                                           # On the first time this is called, we simply re-create                 
62                                                           # the Lark object.                                                      
63                                                           if isinstance(self._guide, CFGGuide):                                   
64                                                               self._guide.parser = Lark(                                          
65                                                                   self._guide.cfg_string,                                         
66                                                                   parser="lalr",                                                  
67                                                                   lexer="contextual",                                             
68                                                                   propagate_positions=False,                                      
69                                                                   maybe_placeholders=False,                                       
70                                                                   regex=True,                                                     
71                                                                   import_paths=[grammars.GRAMMAR_PATH],                           
72                                                               )                                                                   
73                                                                                                                                   
74         2    2308623.0    1e+06     27.3              instruction = self._guide.get_next_instruction(                             
75         1        320.0    320.0      0.0                  state=self._fsm_state[seq_id])                                          
76                                                                                                                                   
77         1       1020.0   1020.0      0.0              if type(instruction) == Generate:  # noqa: E721                             
78         1        460.0    460.0      0.0                  allowed_tokens = instruction.tokens                                     
79                                                       elif type(instruction) == Write:  # noqa: E721                              
80                                                           # TODO: support fast forward tokens                                     
81                                                           allowed_tokens = [instruction.tokens[0]]                                
82                                                       else:                                                                       
83                                                           raise TypeError(                                                        
84                                                               f"Unsupported instruction type {type(instruction)}")                
85                                                                                                                                   
86         2     126930.0  63465.0      1.5              mask = torch.full((scores.shape[-1], ),                                     
87         1       1250.0   1250.0      0.0                              -torch.inf,                                                 
88         1       1990.0   1990.0      0.0                              device=scores.device)                                       
89                                                       # The tokenizer may support more token ids than the model can generate,     
90                                                       # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256      
91                                                       # but scores.shape == torch.Size([128256])                                  
92                                                       # Using NumPy is faster for filtering token ids                             
93         1    5343328.0    5e+06     63.2              allowed_tokens = np.array(allowed_tokens, dtype=np.int64)                   
94         1     280420.0 280420.0      3.3              allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)         
95         2     168360.0  84180.0      2.0              allowed_tokens = allowed_tokens.masked_select(                              
96         1      40100.0  40100.0      0.5                  allowed_tokens < scores.shape[-1])                                      
97         1      17060.0  17060.0      0.2              mask.index_fill_(0, allowed_tokens, 0)                                      
98         1      17080.0  17080.0      0.2              scores.add_(mask)                                                           
99         1       2740.0   2740.0      0.0              return scores 
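For the np.array cost specifically, one option is to memoize the converted container per FSM state so the list-to-array conversion is paid once per state rather than once per token. A minimal sketch follows, assuming the allowed tokens for a given state never change and that the state is hashable (neither is confirmed here for every guide type). `AllowedTokenCache` is a hypothetical name, and the stdlib `array` conversion only stands in for the real `np.array(...)` / `torch.tensor(...)` step.

```python
from array import array
from typing import Callable, Dict, Generic, Hashable, Sequence, TypeVar

T = TypeVar("T")


class AllowedTokenCache(Generic[T]):
    """Memoize the converted allowed-token container per FSM state."""

    def __init__(self, convert: Callable[[Sequence[int]], T]) -> None:
        self._convert = convert
        self._cache: Dict[Hashable, T] = {}
        self.misses = 0  # number of times the conversion actually ran

    def get(self, state: Hashable, allowed_tokens: Sequence[int]) -> T:
        cached = self._cache.get(state)
        if cached is None:
            # First time this state is seen: pay the conversion cost once.
            self.misses += 1
            cached = self._convert(allowed_tokens)
            self._cache[state] = cached
        return cached


# Stand-in conversion using the stdlib. In the logits processor this would be
# the np.array(..., dtype=np.int64) and torch.tensor(...) step instead.
cache = AllowedTokenCache(lambda tokens: array("q", tokens))

first = cache.get(state=3, allowed_tokens=[5, 7, 11])
second = cache.get(state=3, allowed_tokens=[5, 7, 11])
print(first is second, cache.misses)
```

The second lookup returns the same object without converting again. A real version would also need a size bound or eviction policy, since the number of distinct states can grow with the grammar.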

iratebadger avatar Nov 07 '24 05:11 iratebadger

Outlines maintainer here, we're currently working on it!

rlouf avatar Feb 23 '25 08:02 rlouf

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions[bot] avatar Jul 24 '25 02:07 github-actions[bot]

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

github-actions[bot] avatar Aug 24 '25 02:08 github-actions[bot]