Regex ignoring prohibited characters
The bug
I have defined a regex that uses a negated character class ("[^...]") to prevent certain characters from being generated. Guidance still generates these characters, and even seems to be biased towards them.
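To pin down what the `gen_list` regex is meant to allow, here is a minimal sketch that uses Python's standard `re` module as a stand-in for guidance's own regex handling (an assumption: the two are expected to agree on a simple pattern like this one). The helper name `is_valid_footnotes` is hypothetical. It also shows that `[^>"]` excludes only `>` and `"`, so an element like ":a" is still permitted by this pattern.

```python
import re

# The footnotes pattern copied verbatim from gen_list; it describes only the
# text between "[" and "]" because gen_list adds the brackets itself.
FOOTNOTES_RE = re.compile(r'(("[^>"]*", )*("[^>"]*"))?')

def is_valid_footnotes(inner: str) -> bool:
    """Return True if the whole string matches the pattern.
    [^>"] matches any single character except '>' and '"'."""
    return FOOTNOTES_RE.fullmatch(inner) is not None

print(is_valid_footnotes('"a", "b"'))  # True
print(is_valid_footnotes('">", "a"'))  # False: '>' is excluded
print(is_valid_footnotes('":a"'))      # True: ':' is not excluded
print(is_valid_footnotes(''))          # True: the whole group is optional
```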
To Reproduce
import guidance
from guidance import models, user, system, assistant, capture, gen, select, substring, zero_or_more, optional
import json
# model and tokenizer: a QLoRA fine-tuned Hugging Face model and its tokenizer (not shared here)
lm = models.transformers.LlamaChat(model=model, tokenizer=tokenizer, temperature=0.0, echo=False)
@guidance(stateless=True)
def gen_number(lm):
    return lm + gen(regex=r'-?\d+\.?\d*', stop=',')

@guidance(stateless=True)
def gen_string(lm):
    return lm + '"' + gen(stop='"') + '"'

@guidance(stateless=True)
def nullable(lm, generator):
    return lm + select([generator, "null"])

@guidance(stateless=True)
def gen_date(lm, prefix="20"):
    return lm + '"' + prefix + gen(regex=r"\d\d-\d\d-\d\d") + '"'

@guidance(stateless=True)
def gen_list(lm):
    return lm + "[" + gen(regex=r'(("[^>"]*", )*("[^>"]*"))?') + "]"

# @guidance(stateless=True)
# def gen_list(lm):
#     return lm + "[" + gen(stop="]") + "]"
@guidance(stateless=True)
def gen_row(lm):
    out = lm + "{" \
        '"description": ' + nullable(gen_string()) + ", " \
        '"value": ' + nullable(gen_number()) + ", " \
        '"face_amount": ' + nullable(gen_number()) + ", " \
        '"cost": ' + nullable(gen_number()) + ", " \
        '"currency_code": ' + nullable(gen_string()) + ", " \
        '"quantity": ' + nullable(gen_number()) + ", " \
        '"maturity_date_min": ' + nullable(gen_date()) + ", " \
        '"maturity_date_max": ' + nullable(gen_date()) + ", " \
        '"call_date": ' + nullable(gen_date()) + ", " \
        '"coupon_min": ' + nullable(gen_number()) + ", " \
        '"coupon_max": ' + nullable(gen_number()) + ", " \
        '"yield_min": ' + nullable(gen_number()) + ", " \
        '"yield_max": ' + nullable(gen_number()) + ", " \
        '"spread": ' + nullable(gen_number()) + ", " \
        '"percent_net_assets": ' + nullable(gen_number()) + ", " \
        '"footnotes": ' + gen_list() + "}"
    return out
@guidance(stateless=False)
def format_row(lm, row_json):
    with user():
        lm += row_json
    with assistant():
        lm += capture(gen_row(), "formatted")
    return lm

prediction = (lm + format_row(my_json))["formatted"]  # my_json is the input
Unfortunately, I am not able to share my fine-tuned model or an example input.
Even though the regex in gen_list forbids ">" characters in list elements, I still get lists like [">", "a"] after "footnotes". Oddly, when I switch to the commented-out gen_list above (which uses stop="]" instead of a regex), the model never generates ">" characters. The only issue with that approach is that the model sometimes generates numbers without quotes around them, which is why I'm trying to constrain the generation further.
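One way to separate genuine grammar violations from merely unwanted outputs is a post-hoc check on each captured row. This is a sketch using only Python's `re`; it assumes `prediction` is the string captured as "formatted" and that the row always ends with the footnotes list followed by "}" (as `gen_row` builds it). The function name `footnotes_conform` is hypothetical. It flags both the `">"` elements and the unquoted numbers seen with the stop-based variant.

```python
import re

# Same pattern as gen_list: it describes only the text between "[" and "]".
LIST_BODY_RE = re.compile(r'(("[^>"]*", )*("[^>"]*"))?')

# Assumption: gen_row appends the footnotes list last, then "}".
FOOTNOTES_SECTION_RE = re.compile(r'"footnotes": \[(.*)\]\}\s*$', re.DOTALL)

def footnotes_conform(prediction: str) -> bool:
    """Return True if the footnotes list in a generated row matches the
    constrained grammar (quoted strings only, no '>' characters)."""
    section = FOOTNOTES_SECTION_RE.search(prediction)
    if section is None:
        return False  # footnotes key missing or row truncated
    return LIST_BODY_RE.fullmatch(section.group(1)) is not None

print(footnotes_conform('{"description": "x", "footnotes": ["a", "b"]}'))  # True
print(footnotes_conform('{"description": "x", "footnotes": [">", "a"]}'))  # False
print(footnotes_conform('{"description": "x", "footnotes": [1, "a"]}'))    # False
```

The text is matched with a regex rather than `json.loads` because `gen_number` can emit values such as `5.`, which are not valid JSON and would make the whole row unparseable.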
In my training data, there is never a ">" character in the list, and this character isn't in the prompt. I have also noticed a significant bias towards generating other "symbol" characters like ":", "<", ".", "]", "[" and a few others. Sometimes the model also incorrectly combines a symbol with the correct value, producing ":a" instead of "a". Again, the training data never shows such examples, and the alternate gen_list function doesn't do this. It seems like the regex option specifically biases the model towards these characters, even though they should be forbidden.
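To put a number on the apparent symbol bias instead of judging it by eye, one could tally the first character of every footnote element across a batch of outputs from each `gen_list` variant. A standard-library sketch; the helper names are hypothetical, and the input is assumed to be the text found between the footnotes brackets of each output.

```python
import re
from collections import Counter

# Captures the contents of each double-quoted list element.
ELEMENT_RE = re.compile(r'"([^"]*)"')

def leading_char_counts(list_bodies):
    """Count the first character of every non-empty quoted element.
    list_bodies: iterable of strings found between the footnotes brackets."""
    counts = Counter()
    for body in list_bodies:
        for element in ELEMENT_RE.findall(body):
            if element:
                counts[element[0]] += 1
    return counts

def symbol_initial_rate(counts):
    """Fraction of elements whose first character is not a letter or digit."""
    total = sum(counts.values())
    if total == 0:
        return 0.0
    symbols = sum(n for ch, n in counts.items() if not ch.isalnum())
    return symbols / total

counts = leading_char_counts(['"a", ":a"', '">"'])
print(counts)                       # Counter({'a': 1, ':': 1, '>': 1})
print(symbol_initial_rate(counts))  # 0.666...
```

Comparing this rate between the regex-based and stop-based runs on the same inputs would make the bias claim easier to demonstrate in a reproduction.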
System info
- OS: Ubuntu
- Guidance version (guidance.__version__): 0.1.10
Hmm... I tried to reproduce this with Mistral 7B and could not get any violations. Can you try to reproduce it with a model we can both use? Thanks!
See also: https://github.com/guidance-ai/guidance/issues/710