Missing quotation marks in lists of multiple choice values for output type
When a multiple choice (expressed through either Literal or Enum) is contained in a List as an output type, the values of the multiple choice are missing quotation marks. The problem does not occur when the output type is a multiple choice on its own, because the LLM response is then a plain string.
Example to illustrate the problem and to reproduce:
from typing import List, Literal
import outlines
import transformers
from outlines.types.dsl import to_regex, python_types_to_terms
from enum import Enum
TEST_MODEL = "microsoft/Phi-3-mini-4k-instruct"
model = outlines.from_transformers(
    transformers.AutoModelForCausalLM.from_pretrained(TEST_MODEL),
    transformers.AutoTokenizer.from_pretrained(TEST_MODEL),
)
# multiple choices (same issue with Enum as with Literal)
output_type = List[Literal["Paris", "London", "Rome", "Berlin"]]
print(to_regex(python_types_to_terms(output_type))) # \[(Paris|London|Rome|Berlin)(,\ (Paris|London|Rome|Berlin))*\]
result = model("Give me a list of cities.", output_type, max_new_tokens=100)
print(result) # [Paris]
# string
output_type = List[str]
print(to_regex(python_types_to_terms(output_type))) # \[("[^"]*")(,\ ("[^"]*"))*\]
result = model("Give me a list of cities.", output_type, max_new_tokens=100)
print(result) # ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"]
I think there may be a similar issue with strings in other complex output types.
Hi, I propose to work on this.
Awesome! Don't hesitate to reach out if you have questions.
Ok, so the issue is caused at the transformers lib level (so outside of outlines' scope). You can see the method called here in outlines, which calls batch_decode in transformers, and then decode.
Because of this, we have two options: propose a fix in the transformers lib and ask them to merge it, although it may be a breaking change (?), or write our own post-generation function. I made a PR for the second solution, up for discussion to validate that the result is what we want.
Regarding the "self-formatting" solution, I wrote a short regex function to do it; however, I ran into a wall when trying to use the output_type as a formatting heuristic.
Indeed, we pass a typing.List[typing.Literal['Paris', 'London', 'Rome', 'Berlin']] type, but the generate() function expects an Optional[OutlinesLogitsProcessor], so the conversion is done in format_output_type before generate() is called.
Fixing this requires changing the base abstract class, so I'm not sure that's a modification we'd want to make. Maybe an optional method to format the output post-generation would be lighter? Such as:
output_type = List[Literal["Paris", "London", "Rome", "Berlin"]]
result = model("Give me a list of cities.", output_type, max_new_tokens=100)
formatted_result = model.format_result(result, output_type)
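For discussion, here is a minimal sketch of what such a post-generation step could look like. It is an assumption, not the code from the PR: `quote_list_items` is a hypothetical standalone helper (neither outlines nor transformers provides it), and it takes the allowed choices directly instead of the output type.

```python
import json
import re
from typing import List, Sequence


def quote_list_items(raw: str, choices: Sequence[str]) -> List[str]:
    """Parse an unquoted list such as "[Paris, Rome]" into a list of strings.

    Hypothetical helper for illustration only.
    """
    # Try longer choices first so a choice that is a prefix of another
    # one does not shadow it in the alternation.
    ordered = sorted(choices, key=len, reverse=True)
    item = "(?:" + "|".join(re.escape(c) for c in ordered) + ")"
    # Validate the overall shape before extracting the items.
    if re.fullmatch(rf"\[{item}(?:, {item})*\]", raw) is None:
        raise ValueError(f"unexpected output: {raw!r}")
    return re.findall(item, raw)


cities = ["Paris", "London", "Rome", "Berlin"]
items = quote_list_items("[Paris, Rome]", cities)
print(json.dumps(items))
```

The drawback discussed above still applies: the helper needs to know the choices, so it cannot be driven by the logits processor alone.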
I'm not sure I follow you. To me the issue is a lot more straightforward: we do not generate the correct regex for the output type List[Literal["Paris", "London", "Rome", "Berlin"]] and similar ones.
The regex currently created by Outlines is: \[(Paris|London|Rome|Berlin)(,\ (Paris|London|Rome|Berlin))*\], notice there is nothing related to quotation marks around values in the list. So it's only logical the model would yield something like [Paris].
Now, if you provide as an output type Regex(r'\[("(Paris|London|Rome|Berlin)"(,\ "(Paris|London|Rome|Berlin)")*)\]') (so, with quotation marks), you do get the expected results such as ["Paris", "Berlin", "Rome"].
So the solution to the issue is modifying either python_types_to_terms or to_regex to make sure we generate the correct regex from which to build the logits processor.
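The difference between the two regexes can be checked with the standard library alone, without loading a model. The sketch below is illustrative: `list_regex` is a hypothetical helper that only mimics the shape of the regexes quoted above, and it is not the outlines implementation.

```python
import json
import re

CITIES = ["Paris", "London", "Rome", "Berlin"]


def list_regex(choices, quoted):
    """Build a list-of-choices regex, with or without quotes around items."""
    alternatives = "|".join(re.escape(c) for c in choices)
    item = f'"({alternatives})"' if quoted else f"({alternatives})"
    return rf"\[{item}(,\ {item})*\]"


unquoted = list_regex(CITIES, quoted=False)
quoted = list_regex(CITIES, quoted=True)

# The unquoted regex only admits bare words, so the model cannot emit quotes.
print(bool(re.fullmatch(unquoted, "[Paris]")))
print(bool(re.fullmatch(unquoted, '["Paris"]')))

# The quoted regex admits output that also parses as a JSON list.
text = '["Paris", "Berlin", "Rome"]'
print(bool(re.fullmatch(quoted, text)))
print(json.loads(text))
```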
Ok, thanks, I hadn't noticed that there was a pre-processing of output_types in outlines/generator.py - I guess those are the hazards of diving into a complex lib!
So the fix is much simpler - I updated the code in https://github.com/dottxt-ai/outlines/pull/1704/commits/2e214af0ea5989518c58e676da2f9629dcab3783.
Here is the output for the example provided in the OP:
# Using a List[Literal]:
\[("Paris"|"London"|"Rome"|"Berlin")(",\ "("Paris"|"London"|"Rome"|"Berlin"))*\]
["Paris"]
# Using a List[str]
\[("[^"]*")(",\ "("[^"]*"))*\]
["New York City"]
Since the term is structured like this:
typing.List[typing.Literal['Paris', 'London', 'Rome', 'Berlin']]
└── Sequence
    ├── String('[')
    ├── Alternatives(|)
    │   ├── String('Paris')
    │   ├── String('London')
    │   ├── String('Rome')
    │   └── String('Berlin')
    ├── KleeneStar(*)
    │   └── Sequence
    │       ├── String(', ')
    │       └── Alternatives(|)
    │           ├── String('Paris')
    │           ├── String('London')
    │           ├── String('Rome')
    │           └── String('Berlin')
    └── String(']')
I had to exclude square brackets from getting quotation marks.
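To make the tree walk concrete, here is a rough sketch using stand-in dataclasses that only mirror the node names in the tree above. They are not the real outlines classes, and the rule used (quote only String leaves that sit under an Alternatives node) is an assumption chosen for illustration, not necessarily what the linked commit does.

```python
import json
import re
from dataclasses import dataclass
from typing import List


# Stand-in node types, named after the tree above (hypothetical).
@dataclass
class String:
    value: str


@dataclass
class Alternatives:
    terms: List[object]


@dataclass
class Sequence:
    terms: List[object]


@dataclass
class KleeneStar:
    term: object


def render(term, quote=False):
    """Render a term tree to a regex, quoting leaves under Alternatives."""
    if isinstance(term, String):
        text = json.dumps(term.value) if quote else term.value
        return re.escape(text)
    if isinstance(term, Alternatives):
        return "(" + "|".join(render(t, quote=True) for t in term.terms) + ")"
    if isinstance(term, Sequence):
        return "".join(render(t, quote) for t in term.terms)
    if isinstance(term, KleeneStar):
        return "(" + render(term.term, quote) + ")*"
    raise TypeError(f"unsupported term: {term!r}")


def cities():
    return Alternatives([String(c) for c in ["Paris", "London", "Rome", "Berlin"]])


tree = Sequence([
    String("["),
    cities(),
    KleeneStar(Sequence([String(", "), cities()])),
    String("]"),
])
pattern = render(tree)
print(pattern)
```

With this rule the brackets and the ", " separator stay unquoted because they are not children of an Alternatives node. The rule is deliberately naive: it keys on tree position instead of on the original Python type, so it would need refinement for other output types.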
I looked more into it and it's really quite tricky. We should probably solve that in the context of a wider refactoring of the output type system.