
Error when handling partial terminal

Open Gompyn opened this issue 4 months ago • 1 comment

I'm trying to use syncode to generate Java programs, but it fails even when given a valid token. I've simplified the case in the reproduction code below.

I use a simple test case to demonstrate the issue. The grammar has four kinds of terminals: WS, NUMBER, BLOCK_COMMENT, and operators such as + or *. It only allows simple arithmetic expressions, and both WS and BLOCK_COMMENT are ignored.

When the text starts with /**/, which by itself is a legal BLOCK_COMMENT, the tokenizer splits it into /** and /. If the model then selects the /** token, it is not a complete BLOCK_COMMENT, so it gets lexed into /, * and *, which causes an error like this:

[YYYY-MM-DD hh:mm:ss,sss -syncode.grammar_mask.grammar_constrainer] - Parsing failed!
Partial code: /**
Parsed lexical tokens: [Token('SLASH', '/')]
[YYYY-MM-DD hh:mm:ss,sss -syncode.grammar_mask.grammar_constrainer] - --------------------------------------------------
Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/syncode/larkm/parsers/lalr_parser_state.py", line 77, in feed_token
    action, arg = states[state][token.type]
                  ~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'SLASH'
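
The split can be checked directly with the tokenizer (a quick sketch; the exact subword boundaries depend on the vocabulary, this is just what I see with Qwen/Qwen2.5-0.5B):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B')
ids = tok.encode('/**/1+1')
print([tok.decode([i]) for i in ids])
# The comment is not a single token: '/**' appears on its own, followed by a
# token beginning with '/', so the model can legitimately emit '/**' first.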

This happens because terminals are lexed purely from the visible partial code, without considering that the partial code may be a prefix of a longer terminal. When the lexer fails to complete one terminal (BLOCK_COMMENT in this case), it falls back to matching other, shorter terminals. This can be seen in the following match method of the Scanner class, which does the actual lexing of the partial code. That behavior is perfectly fine in the lark library itself, where the input is almost always complete, but it is not OK for a constrained decoding method like syncode. https://github.com/structuredllm/syncode/blob/8afb425570440afcaada961b2da483867d9f2ff8/syncode/larkm/lexer.py#L387-L391
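
A fix would need the lexer to distinguish "does not match" from "could still match if more text arrives". Here is a minimal sketch of that idea (not syncode's actual API; the terminal set is just the one from my test grammar), using the third-party regex module, since stdlib re has no partial matching:

import regex  # third-party 'regex' module; supports partial matching

TERMINALS = {
    'BLOCK_COMMENT': r'(?s)/\*.*?\*/',
    'NUMBER': r'[0-9]+',
    'SLASH': r'/',
    'STAR': r'\*',
}

def match_partial_aware(text: str, pos: int):
    """Return (complete, partial): terminals fully matched at pos, and
    terminals for which the remaining text is only a prefix."""
    complete, partial = [], []
    for name, pat in TERMINALS.items():
        m = regex.match(pat, text[pos:], partial=True)
        if m is None:
            continue
        (partial if m.partial else complete).append(name)
    return complete, partial

# '/**' fully matches SLASH but is also a prefix of BLOCK_COMMENT, so a
# partial-code lexer should defer instead of committing to SLASH:
print(match_partial_aware('/**', 0))  # (['SLASH'], ['BLOCK_COMMENT'])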

Reproduction code

from syncode.grammar_mask.grammar_constrainer import GrammarConstrainer
from syncode.parsers.grammars import Grammar
from syncode.mask_store.byte_tokenizer import ByteTokenizer

from transformers import AutoTokenizer
import torch

class FakeGrammar(Grammar):
  def __init__(self, name: str, ebnf: str):
    self.ebnf = ebnf
    self.name = name

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B')
grammar = FakeGrammar('custom',
r'''start: sum
sum: sum ("+" | "-") term | term
term: term ("*" | "/") factor | factor
factor: "(" sum ")" | NUMBER
NUMBER: /[0-9]+/
BLOCK_COMMENT: /\/\*.*?\*\//s
%ignore BLOCK_COMMENT
%import common.WS
%ignore WS
''')
byte_tokenizer = ByteTokenizer(tokenizer)
grammar_constrainer = GrammarConstrainer(
  grammar=grammar,
  tokenizer=tokenizer, # type: ignore
  byte_tokenizer=byte_tokenizer,
  use_cache=True,
  parse_output_only=True,
  batch_size=1,
  dev_mode=True,
  parser='lalr',
  mode='grammar_strict',
  indent=False
)

vocab_size = len(tokenizer.get_vocab())

def process_logit(input_ids: list[int], logit: torch.Tensor) -> torch.Tensor:
  return grammar_constrainer.mask_scores(torch.tensor([input_ids]), logit.unsqueeze(0)).squeeze(0) # type: ignore

def process_tokens(tokens: list[int]):
  for i in range(len(tokens)):
    logit = torch.zeros((vocab_size,), dtype=torch.float)
    visible_tokens = tokens[:i]
    masked_logit = process_logit(visible_tokens, logit)
    assert masked_logit[tokens[i]] != float('-inf'), f'token {i} ({tokens[i]}, {tokenizer.decode(tokens[i])!r}) is masked'

text = '/**/1+1'
tokens = tokenizer.encode(text)
print(tokens)
process_tokens(tokens)
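
Running this should print the token ids and then fail with the "Parsing failed!" log and the KeyError traceback shown above, at the step where the visible prefix is /**.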

Gompyn • Aug 20 '25 15:08

lark-parser/lark#1344 is probably related; it emphasizes the importance of a custom lexer in such cases.
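
For completeness, lark's custom-lexer hook looks roughly like this. PrefixAwareLexer and everything inside it is hypothetical; only the interface (a Lexer subclass taking lexer_conf, with a lex generator, passed via lexer=) comes from lark:

from lark import Lark, Token
from lark.lexer import Lexer

class PrefixAwareLexer(Lexer):
    def __init__(self, lexer_conf):
        # lark passes its lexer configuration (terminal definitions etc.)
        self.lexer_conf = lexer_conf

    def lex(self, data):
        # A real implementation would notice when the tail of the partial
        # code is a prefix of a longer terminal (e.g. '/**' for BLOCK_COMMENT)
        # and stop yielding instead of falling back to SLASH/STAR.
        for chunk in data.split():
            yield Token('NUMBER', chunk)

parser = Lark('start: NUMBER+\nNUMBER: /[0-9]+/', parser='lalr', lexer=PrefixAwareLexer)
print(parser.parse('1 2 3'))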

Gompyn • Aug 20 '25 15:08