
Incorrect constrained generation for regex

Open · teoremma opened this issue on Oct 4, 2024 · 0 comments

In the following example, the constrained model fails to generate a string that matches the specified grammar:

```python
import traceback

import lark
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from syncode import Grammar, SyncodeLogitsProcessor

grammar_str = r"""
start: print_expr
// COMMENT THE LINE BELOW AND IT WILL WORK
     | update_expr

print_expr: "print" "(" field_access ("," field_access)* ")"

update_expr: field_access "=" quoted_str

field_access: "movie" "[" field_str "]"

field_str: "'" field_name "'"
         | "\"" field_name "\""

field_name: "title"
          | "director"
          | "cast"
          | "country"
          | "date_added"
          | "release_year"
          | "rating"
          | "duration"

quoted_str: "'" /[^']+/ "'"
          | "\"" /[^"]+/ "\""

%ignore " "
"""

prompt = "Given the `movie` record, print the name, the year it was premiered, and its runtime.\n"

model_name = "deepseek-ai/deepseek-coder-1.3b-base"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="balanced_low_0")
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
args = {
    "max_new_tokens": 128,
    "do_sample": False,
    "num_beams": 1,
    "num_return_sequences": 1,
}

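# Plain lark parser, used later to check whether each completion actually parses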
parser = lark.Lark(grammar_str)

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

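# Build the SynCode grammar and the logits processor that enforces it during decoding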
syncode_grammar = Grammar(grammar_str)
syncode_logits_processor = SyncodeLogitsProcessor(
    grammar=syncode_grammar,
    tokenizer=tokenizer,
    parse_output_only=True,
    num_samples=args["num_beams"],
    mode="grammar_strict",
)
syncode_logits_processor.reset(prompt)

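# Greedy decoding (do_sample=False, num_beams=1) with the SynCode mask applied at every step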
generation_output = model.generate(
    inputs,
    pad_token_id=tokenizer.eos_token_id,
    logits_processor=[syncode_logits_processor],
    return_dict_in_generate=True,
    output_scores=True,
    **args,
)

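# Strip the prompt tokens and decode only the generated completion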
outputs = generation_output.sequences
outputs = [x[len(inputs[0]):] for x in outputs]
completions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
completions_tokens = [[tokenizer.decode(tok) for tok in output] for output in outputs]

for i, (comp, toks) in enumerate(zip(completions, completions_tokens)):
    print(f">>-- COMPLETION {i} -->>")
    print(comp)
    print(">------- TOKENS -------<")
    print(toks)
    print("<<--------------------<<")
    try:
        tree = parser.parse(comp)
        print(tree.pretty())
        print("*** CAN PARSE")
    except Exception:
        trace = traceback.format_exc()
        print(trace)
        print("*** CANNOT PARSE")
```

My hunch is that the lexer is handling the regex in `quoted_str` incorrectly. Please also try running the script with the line indicated above commented out (disabling the `update_expr` alternative) and note that generation works in that case; a sanity-check sketch follows below.
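For reference, here is a quick sanity check (my own sketch, reusing the `parser` and `grammar_str` defined above; the two example strings and the `grammar_str_no_update` name are mine, not part of the repro). It shows that plain lark accepts exactly the kind of output the constrained model should be steered toward, so the grammar itself seems fine:

```python
# Sanity check: the reference parser accepts strings that constrained
# generation should be able to produce (example strings are hypothetical;
# any valid print_expr / update_expr would do).
for s in [
    "print(movie['title'], movie['release_year'], movie['duration'])",
    "movie['rating'] = 'PG-13'",
]:
    parser.parse(s)  # raises lark.exceptions.UnexpectedInput if rejected

# The workaround mentioned above, as a one-liner: drop the update_expr
# alternative from `start` and re-run the script with this grammar instead.
grammar_str_no_update = grammar_str.replace("| update_expr", "")
```

Both strings parse with the unmodified grammar, which is consistent with the problem being in the token masking for the `/[^']+/` terminal rather than in the grammar.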

Hope you can look into this soon. Thanks.
