Incorrect constrained generation for regex
In the following example, the constrained model fails to generate a string that matches the specified grammar:
import traceback
import lark
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from syncode import Grammar, SyncodeLogitsProcessor
grammar_str = r"""
start: print_expr
     // COMMENT THE LINE BELOW AND IT WILL WORK
     | update_expr

print_expr: "print" "(" field_access ("," field_access)* ")"
update_expr: field_access "=" quoted_str
field_access: "movie" "[" field_str "]"
field_str: "'" field_name "'"
         | "\"" field_name "\""
field_name: "title"
          | "director"
          | "cast"
          | "country"
          | "date_added"
          | "release_year"
          | "rating"
          | "duration"
quoted_str: "'" /[^']+/ "'"
          | "\"" /[^"]+/ "\""

%ignore " "
"""
prompt = "Given the `movie` record, print the name, the year it was premiered, and its runtime.\n"
model_name = "deepseek-ai/deepseek-coder-1.3b-base"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="balanced_low_0")
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
args = {
    "max_new_tokens": 128,
    "do_sample": False,
    "num_beams": 1,
    "num_return_sequences": 1,
}
# Plain lark parser, used later to check whether each completion matches the grammar.
parser = lark.Lark(grammar_str)
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Constrain generation to the grammar; only the generated output (not the prompt) is parsed.
syncode_grammar = Grammar(grammar_str)
syncode_logits_processor = SyncodeLogitsProcessor(
    grammar=syncode_grammar,
    tokenizer=tokenizer,
    parse_output_only=True,
    num_samples=args["num_beams"],
    mode="grammar_strict",
)
syncode_logits_processor.reset(prompt)
generation_output = model.generate(
    inputs,
    pad_token_id=tokenizer.eos_token_id,
    logits_processor=[syncode_logits_processor],
    return_dict_in_generate=True,
    output_scores=True,
    **args,
)
# Strip the prompt tokens and decode the completions.
outputs = generation_output.sequences
outputs = [x[len(inputs[0]):] for x in outputs]
completions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
completions_tokens = [[tokenizer.decode(tok) for tok in output] for output in outputs]
# Print each completion and check whether plain lark can parse it.
for i, (comp, toks) in enumerate(zip(completions, completions_tokens)):
    print(f">>-- COMPLETION {i} -->>")
    print(comp)
    print(">------- TOKENS -------<")
    print(toks)
    print("<<--------------------<<")
    try:
        tree = parser.parse(comp)
        print(tree.pretty())
        print("*** CAN PARSE")
    except Exception:
        print(traceback.format_exc())
        print("*** CANNOT PARSE")
My hunch is that SynCode's lexer is handling the regexes in quoted_str incorrectly, since plain lark parses such strings without trouble (see the quick check below).
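As a sanity check, here is a small snippet of my own (it reuses the parser object from the repro above; the two sample strings are just inputs I made up):

# Plain lark, without SynCode, accepts both expression forms,
# so the grammar itself appears to be fine.
for sample in (
    "print(movie['title'], movie['release_year'], movie['duration'])",
    "movie['rating'] = 'PG-13'",
):
    print(parser.parse(sample).pretty())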
Please also try running it with the line flagged above commented out (disabling the update_expr alternative) and note that generation then conforms to the grammar. The snippet below applies the same change programmatically.
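For convenience, an equivalent way to build the working variant (grammar_str_print_only is just a name I picked for this report):

# Same grammar with the update_expr alternative dropped from start;
# the update_expr rule stays defined but unreferenced, which lark accepts.
grammar_str_print_only = grammar_str.replace("| update_expr", "")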
Hope you can look into this soon. Thanks.