Add Grammars
Fixes https://github.com/vllm-project/vllm/issues/1229
Implements an incremental LALR / regex parser to determine the legal next-token set.
Try it
I smoke tested with:
Docker image: https://hub.docker.com/r/lapp0/vllm_grammar_branch (commit 2b2b024)
import json

import aiohttp
import tqdm.asyncio


async def fetch_response(session, doc, api_url_base, timeout=60):
    prompt = get_prompt(doc)
    grammar = get_grammar(doc)
    headers = {"User-Agent": "Test Client"}
    pload = {
        "model": "YOUR HF MODEL URI",
        "prompt": prompt,
        "grammar": grammar,
        # optional:
        # "n": 8,
        # "use_beam_search": True,
        # "temperature": 0.0,
        # "add_generation_prompt": False,
        # "max_tokens": 4096,
        # "logprobs": True,
    }
    async with session.post(
        f"{api_url_base}/v1/completions",
        headers=headers,
        json=pload,
        timeout=timeout,
    ) as response:
        return doc, await response.json()


async def do_smoke_test(api_url_base, max_concurrent=16, timeout=2000):
    documents = json.load(open("smoke_testing_docs.json"))[:256]
    results = {}
    connector = aiohttp.TCPConnector(limit=max_concurrent)
    session_timeout = aiohttp.ClientTimeout(
        total=None,
        sock_connect=timeout,
        sock_read=timeout,
    )
    async with aiohttp.ClientSession(connector=connector, timeout=session_timeout) as session:
        tasks = [fetch_response(session, doc, api_url_base, timeout) for doc in documents]
        for doc, response in await tqdm.asyncio.tqdm.gather(*tasks):
            results[doc] = response
    return results
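To run the smoke test end to end (the URL below is only a placeholder for wherever the server is listening):

    import asyncio

    # Placeholder URL; point this at the vLLM OpenAI-compatible server under test.
    if __name__ == "__main__":
        results = asyncio.run(do_smoke_test("http://localhost:8000"))
        print(f"collected {len(results)} responses")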
TODO

InteractivePredictiveLALRParser
- [x] Load Lark EBNF grammars
- [x] Track terminals and state transitions
- [x] Determine token validity, whether it partially or wholly completes the terminal's pattern

TokenTrie
- [x] Efficiently retrieve candidate tokens

NextTokenValidator
- [x] Update parser state
- [x] Get valid token IDs given a tokenizer vocabulary and the in-progress completion's text

GrammarLogitProcessor
- [x] Implement def __call__(self, token_ids, logits), which updates the parser's state with new tokens and filters based on NextTokenValidator's valid token IDs
- [x] Implement testing
  - [x] Performance tests: see docs
  - [/] Documentation
  - [x] Clean up tests; right now they're all lurking in grammar.py
  - [x] Handle EOS token
  - [x] Finish test cases

IncrementalParserState
- Refactor into an immutable parser state
Ramblings from previous implementation:
Grammar Token Filter Algorithm
The GTF algorithm involves calculating the set of valid next tokens given an incomplete sequence and a grammar.
Its core components are:
- A set of tokens, where every token is either a single character or can be generated from a combination of other tokens (see the trie sketch below)
- An interactive parser which can determine the validity of incomplete sequences
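For concreteness, here is a minimal sketch of a character-level token trie built from a tokenizer vocabulary. The class and method names are illustrative only, not the PR's actual TokenTrie API:

    class TokenTrieSketch:
        """Minimal character-level trie over a token vocabulary (illustration only)."""

        def __init__(self, vocabulary):
            self.tokens = set(vocabulary)  # full tokens, as opposed to mere prefixes
            self.next_chars = {}           # prefix -> set of possible next characters
            for token in vocabulary:
                for i in range(len(token)):
                    self.next_chars.setdefault(token[:i], set()).add(token[i])

        def children(self, prefix=""):
            """Every prefix reachable from `prefix` by appending one character."""
            return {prefix + c for c in self.next_chars.get(prefix, set())}


    trie = TokenTrieSketch(["a", "are", "la", "lang", "language", "large"])
    print(trie.children(""))    # {'a', 'l'} (set order may vary)
    print(trie.children("la"))  # {'lan', 'lar'}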
Lazy Approach
The simplest algorithm is:

    def get_valid_tokens(parser, base_sequence, token_vocabulary):
        valid_tokens = set()
        for token in token_vocabulary:
            if parser.is_valid_sequence(base_sequence + token):
                valid_tokens.add(token)
        return valid_tokens
This approach has two core inefficiencies:
- If the token "foobar" is valid, then we already know "foo" is valid, so redundant work is performed
- The parser applies state transitions for base_sequence once for every token
This PR's Approach
The current GTF algorithm improves on the lazy approach in two aspects:
- token_vocabulary is a trie, allowing us to check "foo", and if it is invalid, we know "foobar", "foobaz", and "foobarbaz" are also invalid.
- The parser is interactive, meaning it doesn't need to recalculate the base_sequence each time.
The current algorithm is a depth-first search of the token trie with a base_sequence-warmed parser:
    def get_valid_tokens(parser, token_trie, trie_root=""):
        valid_tokens = set()
        for token in token_trie.children(trie_root):
            if parser.is_valid_next_token(token):
                child_parser = parser.step_seq(token)
                valid_tokens.add(token)
                valid_tokens |= get_valid_tokens(child_parser, token_trie, token)
        return valid_tokens
The main weakness of this implementation involves regular expressions. If a terminal rule is a regular expression, an incomplete match must be searched for redundantly.
Optimal (Future) Approach
The optimal GTF approach requires every terminal rule to be a single character; all terminals, including regular expressions, must be decomposed.
For example, the regular expression
\d{5}(-\d{4})?
must be decomposed into:
digit = "\d"
five_digits = digit, digit, digit, digit, digit;
four_digits = digit, digit, digit, digit;
optional_suffix = "-", four_digits;
zipcode = five_digits, [optional_suffix];
Additionally, we can use a helper function legal_chars(character_expr) which retrieves all characters legal within an atomic character expression, e.g.
legal_chars("\d") = set(["0", "1", "2", ...])
legal_chars("[ae]") = set(["a", "e"])
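A rough sketch of such a helper, assuming only a handful of atomic expression forms need special handling (the brute-force fallback over printable ASCII is purely for illustration, not the PR's implementation):

    import re
    import string

    def legal_chars(character_expr):
        """Return the set of single characters matched by an atomic character
        expression. Illustrative sketch only."""
        if character_expr == r"\d":
            return set(string.digits)
        if character_expr == r"\w":
            return set(string.ascii_letters + string.digits + "_")
        # Fallback: test every printable ASCII character against the expression.
        pattern = re.compile(character_expr)
        return {c for c in string.printable if pattern.fullmatch(c)}

    assert legal_chars(r"\d") == set("0123456789")
    assert legal_chars("[ae]") == {"a", "e"}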
With this optimization the GTF algorithm would be as follows:
    def get_valid_tokens(parser, token_trie, token_trie_roots=None):
        valid_tokens = set()
        for next_terminal_expr in parser.get_next_terminals_exprs():
            parser_next_chars = legal_chars(next_terminal_expr)
            # intersect the parser's legal next characters with the trie's next prefixes
            legal_token_prefixes = token_trie.children(token_trie_roots) & parser_next_chars
            if legal_token_prefixes:
                child_parser = parser.transition(next_terminal_expr)
                legal_token_suffixes = get_valid_tokens(child_parser, token_trie, legal_token_prefixes)
                valid_tokens |= token_trie.combine(legal_token_prefixes, legal_token_suffixes)
        return valid_tokens
This function requires applying a state transition only once for every transition which is legal within the token set, as opposed to the current implementation, which applies a state transition once for each unique token trie node.
Breaking down into single-character terminals provides another advantage: we don't have to recompute a partial regular-expression match redundantly. If foo matches (foo|bar)(bazbif), we don't need to recalculate the entire regex for foobaz again. In fact, we don't compute regular expressions at all; we simply generate the valid character set for a given atomic character expression and intersect it with the trie's valid token prefix set.
Example
I use a simple Thompson's-style regex-to-NFA construction to generate the ε-NFA dict via automata_toolkit.
Sample code which assigns random values to logits and generates a grammar-constrained completion:
    import numpy as np

    # EpsilonNFA, TokenConstraintLogitProcessor, tokenizer, and regex_to_nfa
    # are assumed to be in scope (from this branch and automata_toolkit).
    regexp = r"(large )?(language )((models )+(inference engines ))(are )((useful)+((very )*complex))."
    sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts) / np.sum(np.exp(lgts)))

    for i in range(4):
        logit_processor = TokenConstraintLogitProcessor(
            tokenizer=tokenizer,
            nfa=EpsilonNFA(nfa=regex_to_nfa.regex_to_nfa(regexp)),
        )

        token_ids = []
        while True:
            logits = logit_processor(
                token_ids=token_ids,
                logits=np.random.uniform(-10, 10, len(tokenizer.vocab)),
            )
            new_token_id = sample_from_logits(logits)
            token_ids.append(new_token_id)
            if new_token_id == tokenizer.eos_token_id:
                break

        print(f"run #{i}")
        print("\ttokenid", token_ids)
        print("\ttokens:", [tokenizer.decode(tok_id) for tok_id in token_ids])
        print("\tresult:", tokenizer.decode(token_ids, skip_special_tokens=False))
Output:
regexp: r"(large )?(language )((models )+(inference engines ))(are )((useful)+((very )*complex))."
run #0
tokenid [2220, 28712, 28721, 104, 305, 28708, 113, 2851, 28708, 490, 3418, 5149, 28713, 264, 267, 1001, 112, 452, 720, 49, 2]
tokens: ['la', 'r', 'g', 'e', 'l', 'a', 'n', 'gu', 'a', 'ge', 'mo', 'del', 's', 'a', 're', 'co', 'm', 'pl', 'ex', '.', '</s>']
result: large language models are complex.</s>
run #1
tokenid [2220, 7879, 104, 28705, 28714, 2374, 465, 4319, 9417, 358, 17048, 597, 104, 28705, 675, 452, 720, 49, 2]
tokens: ['la', 'rg', 'e', '', 'l', 'angu', 'age', 'inf', 'eren', 'ce', 'engines', 'ar', 'e', '', 'com', 'pl', 'ex', '.', '</s>']
result: large language inference engines are complex.</s>
run #2
tokenid [2220, 7879, 104, 28705, 4730, 120, 465, 968, 1190, 264, 267, 1429, 28724, 4630, 49, 2]
tokens: ['la', 'rg', 'e', '', 'lang', 'u', 'age', 'mod', 'els', 'a', 're', 'ver', 'y', 'complex', '.', '</s>']
result: large language models are very complex.</s>
run #3
tokenid [16962, 543, 113, 28721, 120, 465, 4319, 9417, 28717, 104, 2536, 1303, 597, 104, 332, 28713, 797, 120, 28714, 49, 2]
tokens: ['large', 'la', 'n', 'g', 'u', 'age', 'inf', 'eren', 'c', 'e', 'eng', 'ines', 'ar', 'e', 'u', 's', 'ef', 'u', 'l', '.', '</s>']
result: large language inference engines are useful.</s>
Please observe that ["la", "rg", "e"] and ["large"] are both valid token sequences within the grammar, and either may be generated.
Thanks for putting this together @lapp0 .
I managed to integrate it with the rest of the vllm and got legit outputs!
A minor change I made was that the return value of GrammarLogitProcessor.__call__ should be a tensor, not a list.
Here is the minor change I made to make it work; hope it helps:
N = len(self.tokenizer.vocab)
mask = torch.zeros(N, dtype=torch.bool)
valid = torch.tensor(valid_token_ids, dtype=torch.long)
mask[valid] = True
logits[~mask] = float('-inf')
return logits
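For reference, the same masking idea as a tiny standalone snippet (the logits values and vocabulary size here are made up):

    import torch

    logits = torch.tensor([0.5, -1.0, 2.0, 0.1])
    valid_token_ids = [0, 2]
    mask = torch.zeros(len(logits), dtype=torch.bool)
    mask[torch.tensor(valid_token_ids, dtype=torch.long)] = True
    logits[~mask] = float("-inf")
    print(logits)  # tensor([0.5000, -inf, 2.0000, -inf])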
Appreciate your review, fix, and interest @xuy. Will integrate that after I'm done with some bug fixes!
What needs to happen to get this grammar pull request approved and merged? I'd love to start using grammars with vLLM.
The code had a major bug - only single character tokens were being selected.
I just pushed a fix which fixes the bug, makes the parser functional and immutable, caches the parser based on the state stack, and is a bit cleaner.
Currently validating:
- ~~25 tokens per second~~
- ~~39 tokens per second~~
- ~~46 tokens per second with nproc = 2~~
- ~~67 tokens per second with nproc = 4~~
- ~~74 tokens per second with nproc = 8~~ (moved multiprocessing to a separate grammar-multiprocessing branch)
- ~~47 tokens per second~~
- 58 tokens per second
on a single core of my laptop, indicating this would be a bottleneck for vLLM. Will try to optimize further.
Does this only work with the OpenAI API at the moment? If so, could it be added to the vllm api as well?
Works nicely so far. I noticed the preprocessing for batching being done on only one core and hence significantly stalling the process. Is that due to grammar implementation? And is there a way to fix that, to either use GPU or more than a single core?
@lapp0 Could you post your multiprocessing branch, even if it's incomplete? I've been trying to implement it myself, but it seems I can't get it quite right.
@brucethemoose It's pretty poorly implemented, but here you go: https://github.com/lapp0/vllm/tree/grammar-multiprocessing
I've been working on integrating some of my caching changes into https://github.com/outlines-dev/outlines which already has regex-based guidance for vLLM.
Tested the grammar support from your branch.
Additional changes I made:
- Rebased it on the latest vLLM main
- Added receiving the grammar parameter in ChatCompletionRequest as well and handled it the same way in the v1/chat/completion request handler, since I use an instruct fine-tuned model via that path.
Model: TheBloke/deepseek-coder-33B-instruct-AWQ
System: "You are a helpful AI assistant. You give concise answers. If you do not know something, then say so."
User: "Write down the first 10 prime numbers as a comma separated list on a single line. Do not write anything else."
Without the grammar the model gives this response:
"2, 3, 5, 7, 11, 13, 17, 19, 23, 29"
So in the grammar I intentionally denied any use of white-space, so the expected output must be:
"2,3,5,7,11,13,17,19,23,29"
Grammar:
?start: SIGNED_NUMBER ( "," SIGNED_NUMBER )*
%import common.SIGNED_NUMBER
While it conforms to the grammar it fails to produce the two digit prime numbers:
"2,3,5,7,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,"
It may happen that the grammar code somehow prevents it from writing 11 there, i.e. it cannot write a number with multiple digits.
Changed the grammar to be more strict and simpler:
?start: DIGIT+ ( "," DIGIT+ )*
%import common.DIGIT
With this grammar the model produces the primes, but cannot stop. Therefore there is a problem in the code preventing it from generating the EOS token. Generating the stop token should be allowed wherever it is consistent with the grammar.
"2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,"
Grammar support would be really awesome to have for my use case. I actually started implementing Lark support and already figured out the PR's algorithm in my tests outside vLLM when I found your PR. It is really great that you had this much progress already, so there is a chance to have grammar support soon.
Even if we cannot use regex in our grammar, having any kind of grammar support would still be a huge win. Also, the grammar support would allow for reliable function calling, which is also in the works (#2360). They refactor the spaghetti code in the OpenAI compatible server in that first PR.
In llama.cpp there is a grammar named GBNF, which is an EBNF variant. That already works and its integration with the sampling code can give us some ideas on how to optimize this in vLLM.
When I change the grammar to allow whitespace, then it can generate the primes properly:
"2, 3, 5, 7, 11, 13, 17, 19, 23, 29"
Grammar:
?start: _WS? DIGIT+ ( _WS? "," _WS? DIGIT+ )* _WS?
%import common.DIGIT
%import common.WS -> _WS
It worked without code changes and could stop. I do not know why it cannot stop if no white-space is allowed. The model writes spaces only after the commas and no newline is generated at the end of completion.
Narrowed it down to this grammar. It works, but has to produce a newline at the end, so it can stop:
?start: DIGIT+ ( "," DIGIT+ )* _WS?
%import common.DIGIT
%import common.WS -> _WS
Output as a Python string:
"2,3,5,7,11,13,17,19,23,29\n"
So the actual bug is that the grammar does not let the LLM generate a stop token if the grammar does not allow white-space at the very end. At least that seems to be the case based on the few tests I've done. I'm not sure whether it is a limitation of this specific LLM due to how it was trained (it must write a newline before EOS) or a bug in the sampler integration of the grammar (it does not allow EOS in that case).
The CPU overhead of the grammar is indeed horrible. Speed is down from 32T/s to 8.4T/s with the above very simple grammar.
@viktor-ferenczi
The parser doesn't handle ambiguous terminals well. Could you try converting them to rules? Something along the lines of:
signed_number: ["+"|"-"] number
number: float | int
float: int exp | decimal exp?
decimal: int "." int? | "." int
exp: ("e"|"E") signed_int
signed_int: ["+"|"-"] int
int: DIGIT+
DIGIT: "0".."9"
And yes, the speed is bad. Outlines addresses this by precompiling the regex FSM and using Numba. I'm leaning heavily towards thinking vLLM should be a strong, simple inference engine and outlines should be a wrapper on top for grammars.
Outlines vLLM CFG implementation merged yesterday https://github.com/outlines-dev/outlines/pull/517
The grammar you suggested crashes vLLM with this exception:
TypeError: UnexpectedToken.__init__() missing 2 required positional arguments: 'token' and 'expected'
The traceback is useless because of the use of Ray (2 GPUs).
Performance: I was running vLLM with cProfile and executed the completion some 50 times in about 2 minutes. Found the grammar responsible for only ~550 ms of CPU runtime, so I don't see from the profiling data where the experienced slowdown comes from. The grammar's CPU load is 60-70% of a core, so it does not seem to be CPU bound there. I guess the load does not show up on the CPU or in the Python profiler, but is introduced by the use of tensors (GPU RAM access?) or similar. I don't know enough Torch and CUDA yet to tell exactly.
I'm leaning heavily towards thinking vLLM should be a strong, simple inference engine and outlines should be a wrapper on top for grammars.
Where would you put the grammar support? If we keep it inside vLLM, then it can be used via the REST APIs. That's what I prefer, at least for my use case. It allows for hosting the LLM separately from the application and better scalability, all without having to write a custom server for each application or forcing the application to run the LLM directly in-process.
The exception due to the grammar:
...
File "/home/viktor/dep/vllm-contrib/vllm/model_executor/layers/sampler.py", line 155, in _apply_logits_processors
logits_row = logits_processor(token_ids, logits_row)
File "/home/viktor/dep/vllm-contrib/vllm/grammar.py", line 472, in __call__
return ray.get(result_id)
... ray ...
ray.exceptions.RaySystemError: System error: Failed to unpickle serialized exception
... ray ...
TypeError: UnexpectedToken.__init__() missing 2 required positional arguments: 'token' and 'expected'
So it cannot relay Lark's UnexpectedToken error. It is also not handled properly and turned into a Bad Request error by the API Server, apparently.
Performance: I was running vLLM with cProfile and executed the completion some 50 times in about 2 minutes. Found the grammar responsible for only ~550 ms of CPU runtime, so I don't see from the profiling data where the experienced slowdown comes from. The grammar's CPU load is 60-70% of a core, so it does not seem to be CPU bound there. I guess the load does not show up on the CPU or in the Python profiler, but is introduced by the use of tensors (GPU RAM access?) or similar. I don't know enough Torch and CUDA yet to tell exactly.
Are you using multiple GPUs? I'm seeing a substantial slowdown when passing the tensors to the logits processor ray actor.
Where would you put the grammar support? If we keep it inside vLLM, then it can be used via the REST APIs. That's what I prefer, at least for my use case. It allows for hosting the LLM separately from the application and better scalability, all without having to write a custom server for each application or forcing the application to run the LLM directly in-process.
https://outlines-dev.github.io/outlines/reference/vllm/
@lapp0 Tried the outlines.serve.serve way. The JSON schema and Regex work, but the grammar (cfg) does not. See the outlines bug report on this. Also, that solution does not work with tensor parallel at all (see bug ticket). It looks like everything is implemented, just not reliable yet.
There are already libraries actively maintained for guided generation that can integrate with vLLM, like Outlines. I would be wary of introducing code that is tangentially related to this library and will require a substantial amount of maintenance when this can be solved by an import. Why not contribute this code to these libraries and import them here instead?
http://outlines-dev.github.io/outlines/reference/vllm/
@jqueguiner The custom logits processors need some more information to be passed to avoid having to patch vLLM the hard way. Primary example is a way to identify the sequence (seq_id) and maybe more. Please look into the implementation of outlines.serve.serve, specifically _patched_apply_logits_processors.
The seq_id can probably be replaced with a hash of the token ids if that’s really the blocker. But that’s beside the point, even if we needed to pass seq_id, I agree with @jqueguiner that it’s an easier change for the vLLM team that requires substantially less maintenance over time.
@jqueguiner The custom logits processors need some more information to be passed to avoid having to patch vLLM the hard way. Primary example is a way to identify the sequence (seq_id) and maybe more. Please look into the implementation of outlines.serve.serve, specifically _patched_apply_logits_processors.
This experimental change where the state is cached by the hash of the prior token ids is working for me so far:
https://github.com/outlines-dev/outlines/commit/8b1ff9a7863915f5fdee421d4cae8f0840d58b33#diff-f65ffb5f52b2e358c713ccb8f32a700769426c6c8b655f689e3cdccae07d22ac
A hash on preceding tokens is even better than seq_id, because it would allow for further optimization should prompts be repeated.
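A rough sketch of that caching idea, keyed by the hash of the token ids seen so far (the class and callback names are hypothetical, not outlines' actual API):

    class CachedGrammarLogitsProcessor:
        """Cache grammar/FSM state by a hash of the generated token ids, so that
        repeated prefixes (beams, repeated prompts) don't re-run the parser from
        scratch. Illustrative sketch only."""

        def __init__(self, new_state, advance, mask_logits):
            self.new_state = new_state      # () -> initial parser state
            self.advance = advance          # (state, token_id) -> next state
            self.mask_logits = mask_logits  # (state, logits) -> masked logits
            self._cache = {}                # hash of token id prefix -> state

        def _state_for(self, token_ids):
            key = hash(tuple(token_ids))
            if key not in self._cache:
                if token_ids:
                    # Reuse the longest cached prefix, then advance by one token.
                    prev = self._state_for(token_ids[:-1])
                    self._cache[key] = self.advance(prev, token_ids[-1])
                else:
                    self._cache[key] = self.new_state()
            return self._cache[key]

        def __call__(self, token_ids, logits):
            return self.mask_logits(self._state_for(list(token_ids)), logits)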
Hi everyone, thank you so much for the very active discussion here. As a vLLM maintainer, I want to express my sincere thanks for your enthusiasm. vLLM as a project is focused on optimizing LLM inference and providing a fully compatible OpenAI API; constrained decoding is not our strong suit, and we don't have the expertise to maintain it.
@lapp0 would you be able to consider closing this PR and merging into outlines instead? I think you mentioned it here. I would very much like to use outlines directly in vLLM after #2488 is merged (or sooner; adding it to the completion API is another option).
@lapp0 and @viktor-ferenczi, please let us know what interface and scheduling change on the vLLM side is needed to better support this functionality.
Sure @simon-mo, will follow up with you about any changes to vLLM which are necessary. Thanks for your enthusiastic support!
Closing in favor of outlines. A few changes necessary in outlines to consider guidance ready for vLLM:
- https://github.com/outlines-dev/outlines/issues/524
- https://github.com/outlines-dev/outlines/pull/539