Error when handling partial terminal
I'm trying to use syncode to generate Java programs, but it fails when given a valid token. I have simplified the case in the reproduction code below.
I use a simple test case to demonstrate the problem. The grammar has four kinds of terminals: WS, NUMBER, BLOCK_COMMENT, and operators such as + or *. It only allows simple arithmetic expressions, and both WS and BLOCK_COMMENT are ignored.
When the text starts with /**/, which by itself is a legal BLOCK_COMMENT, the tokenizer splits it into /** and /. If the /** token is then selected by the model, it is not a complete BLOCK_COMMENT, so it is lexed into /, * and *, which causes an error like this:
[YYYY-MM-DD hh:mm:ss,sss -syncode.grammar_mask.grammar_constrainer] - Parsing failed!
Partial code: /**
Parsed lexical tokens: [Token('SLASH', '/')]
[YYYY-MM-DD hh:mm:ss,sss -syncode.grammar_mask.grammar_constrainer] - --------------------------------------------------
Traceback (most recent call last):
File ".../lib/python3.11/site-packages/syncode/larkm/parsers/lalr_parser_state.py", line 77, in feed_token
action, arg = states[state][token.type]
~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'SLASH'
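The BPE split itself can be checked independently of syncode, using the same Qwen/Qwen2.5-0.5B tokenizer as in the reproduction below (a minimal sanity check; the exact split may vary between tokenizers):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B')
ids = tokenizer.encode('/**/1+1')
# Decoding each BPE token individually shows the comment split across tokens,
# e.g. '/**' followed by a token starting with '/'.
print([tokenizer.decode([i]) for i in ids])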
This is because the terminals are lexed purely based on the visible partial code, without considering the possibility that the partial code could be a prefix of a longer terminal. When the lexer fails to match one terminal in the partial code (BLOCK_COMMENT in this case), it falls back to matching other terminals. This can be seen in the following match method of the Scanner class, which does the actual lexing of the partial code. This behavior is perfectly fine in the lark library, because the input is almost always complete, but it is not acceptable for a constrained decoding method like syncode.
https://github.com/structuredllm/syncode/blob/8afb425570440afcaada961b2da483867d9f2ff8/syncode/larkm/lexer.py#L387-L391
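One possible direction (only a sketch, not syncode's actual code) is to check whether the partial code could still be a prefix of a longer terminal before falling back to other terminals. The third-party regex module supports exactly this kind of prefix check via partial matching:
import regex  # third-party package; unlike the stdlib re, it supports partial matching

# BLOCK_COMMENT pattern from the grammar in the reproduction below (DOTALL, matching the /s flag)
BLOCK_COMMENT = r'/\*.*?\*/'

def may_extend_to(pattern: str, partial_text: str) -> bool:
    """True if partial_text already matches the pattern or is a prefix of a possible match."""
    return regex.fullmatch(pattern, partial_text, flags=regex.DOTALL, partial=True) is not None

print(may_extend_to(BLOCK_COMMENT, '/**'))   # True: could still grow into a block comment
print(may_extend_to(BLOCK_COMMENT, '/**/'))  # True: already a complete block comment
print(may_extend_to(BLOCK_COMMENT, '1+'))    # False: can never become a block comment
A lexer that knows '/**' may still extend into BLOCK_COMMENT could defer the decision instead of re-lexing it as SLASH, STAR, STAR.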
Reproduction code
from syncode.grammar_mask.grammar_constrainer import GrammarConstrainer
from syncode.parsers.grammars import Grammar
from syncode.mask_store.byte_tokenizer import ByteTokenizer
from transformers import AutoTokenizer
import torch
class FakeGrammar(Grammar):
    def __init__(self, name: str, ebnf: str):
        self.ebnf = ebnf
        self.name = name
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B')
grammar = FakeGrammar('custom',
r'''start: sum
sum: sum ("+" | "-") term | term
term: term ("*" | "/") factor | factor
factor: "(" sum ")" | NUMBER
NUMBER: /[0-9]+/
BLOCK_COMMENT: /\/\*.*?\*\//s
%ignore BLOCK_COMMENT
%import common.WS
%ignore WS
''')
byte_tokenizer = ByteTokenizer(tokenizer)
grammar_constrainer = GrammarConstrainer(
    grammar=grammar,
    tokenizer=tokenizer,  # type: ignore
    byte_tokenizer=byte_tokenizer,
    use_cache=True,
    parse_output_only=True,
    batch_size=1,
    dev_mode=True,
    parser='lalr',
    mode='grammar_strict',
    indent=False
)
vocab_size = len(tokenizer.get_vocab())
def process_logit(input_ids: list[int], logit: torch.Tensor) -> torch.Tensor:
    return grammar_constrainer.mask_scores(torch.tensor([input_ids]), logit.unsqueeze(0)).squeeze(0)  # type: ignore

def process_tokens(tokens: list[int]):
    for i in range(len(tokens)):
        logit = torch.zeros((vocab_size,), dtype=torch.float)
        visible_tokens = tokens[:i]
        masked_logit = process_logit(visible_tokens, logit)
        assert masked_logit[tokens[i]] != float('-inf'), f'token {i} ({tokens[i]}, {tokenizer.decode(tokens[i])!r}) is masked'
text = '/**/1+1'
tokens = tokenizer.encode(text)
print(tokens)
process_tokens(tokens)
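For comparison, plain lark parses the complete text with the same grammar without any problem, because its lexer only ever sees finished input (a small sanity check, assuming the standard lark package is installed):
from lark import Lark

parser = Lark(r'''
start: sum
sum: sum ("+" | "-") term | term
term: term ("*" | "/") factor | factor
factor: "(" sum ")" | NUMBER
NUMBER: /[0-9]+/
BLOCK_COMMENT: /\/\*.*?\*\//s
%ignore BLOCK_COMMENT
%import common.WS
%ignore WS
''', parser='lalr')

# The input is complete, so /**/ is lexed as one BLOCK_COMMENT and ignored.
print(parser.parse('/**/1+1'))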
lark-parser/lark#1344 is probably related; it emphasizes the importance of a custom lexer in such cases.
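For reference, lark's documented custom-lexer hook looks roughly like this (a toy sketch, not a proposed patch): the parser consumes pre-typed tokens, so no regex-based lexing of partial text happens at all.
from lark import Lark, Token
from lark.lexer import Lexer

class PassThroughLexer(Lexer):
    """Toy custom lexer: trusts already-typed (type, value) pairs instead of re-lexing text."""
    def __init__(self, lexer_conf):
        pass

    def lex(self, data):
        for token_type, value in data:
            yield Token(token_type, value)

toy = Lark('''
start: NUMBER (PLUS NUMBER)*
NUMBER: /[0-9]+/
PLUS: "+"
''', parser='lalr', lexer=PassThroughLexer)

print(toy.parse([('NUMBER', '1'), ('PLUS', '+'), ('NUMBER', '1')]))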