Implement prompt/generation alignment
[updated 2024-06-28]
The aim of this PR is to implement prompt token alignment.
The idea is to modify the states_to_token_maps of the Guide so that it also includes the characters of some of the last tokens of the prompt, which could then be replaced by a different token that contains the same characters plus extra characters belonging to the generation (a crossing token).
To do so, when receiving the user's prompts (that is, after the OutlinesLogitsProcessor has already been initialized with its FSM), we copy the FSM as many times as there are prompts and apply prompt token alignment to each copy (as the modification of the states_to_token_maps depends on the content of each prompt).
At the end of the process, we modify the generated sequences to remove the leading characters that correspond to the ends of the user prompts.
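For illustration, here is a rough, self-contained sketch of that flow; the helper names (`apply_prompt_alignment`, `strip_prompt_overlap`) are made up for the example and are not the actual code of this PR:

```python
import copy


def prepare_fsms(base_fsm, prompts, tokenizer):
    """Copy the FSM once per prompt and apply prompt token alignment to each
    copy, since the modified states_to_token_maps depends on the prompt."""
    fsms = []
    for prompt in prompts:
        fsm = copy.deepcopy(base_fsm)
        # Hypothetical helper: rewires states_to_token_maps so that crossing
        # tokens covering the end of `prompt` are allowed as first tokens.
        fsm.apply_prompt_alignment(prompt, tokenizer)
        fsms.append(fsm)
    return fsms


def strip_prompt_overlap(generation: str, realigned_prompt_end: str) -> str:
    """Remove the leading characters of the generation that re-spell the end
    of the user's prompt, so the user only sees new text."""
    assert generation.startswith(realigned_prompt_end)
    return generation[len(realigned_prompt_end):]
```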
This is not intended to be merged, I was just wondering whether you think this is a promising direction to look into
I think this is the right general direction.
- The case in which the text that comes after the "boundary" of a token matching the end of the prompt does not exist as a token in the vocabulary by itself is not covered
Could you illustrate this? I had a PR opened (can't find it right now) where I iterated once over the vocabulary to find the overlapping tokens.
Making up a fake example: my prompt is "Good mor". Let's say there's a token for "mor" and it's the last one of the prompt. We would want token alignment to replace "mor" with "morning". However, if the token "ning" does not exist by itself, then there's nothing in the states_to_token_maps that corresponds to it because, at this point, the character-based FSM that would allow generating "ning" has already been turned into a token-based mapping.
I was looking at creating states_to_token_maps only after the call is made (and the FSM is updated) but that would add too much overhead.
I was then thinking that a solution could be to create at initialization a mapping that contains information about both characters and tokens (so we would have some states with no tokens leading to them that would be used for the token alignment)
How about looping over the entire vocabulary and storing the tokens that accept "mor" as a prefix? Then, in the unconstrained case, the first state of the FSM would have transitions to the overlapping tokens only?
Haven't taken the time to think about the constrained case yet.
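Roughly something like this, as a sketch (the plain string-to-id vocabulary is an assumption for the example; a real tokenizer's byte/space conventions would need handling):

```python
def tokens_overlapping_prompt_end(vocabulary: dict[str, int], prompt_end: str) -> list[int]:
    """Collect token ids whose string starts with `prompt_end` and extends
    past it, e.g. "morning" when prompt_end is "mor"."""
    return [
        token_id
        for token_str, token_id in vocabulary.items()
        if token_str.startswith(prompt_end) and len(token_str) > len(prompt_end)
    ]
```

In the unconstrained case, the first state would then only have transitions to the token ids returned here.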
I had not realized that I could walk the states_to_token_maps character by character for the postfix part of the crossing tokens in the constrained case. I think it works with almost no additional overhead like that. Let me know if you think it's fine and I'll update the tests afterward
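Something like this is what I have in mind for that walk, as a sketch (assuming access to a character-level transition map like the one the Guide is built from; the names are illustrative):

```python
from typing import Optional


def walk_postfix(char_transitions: dict[int, dict[str, int]],
                 start_state: int,
                 postfix: str) -> Optional[int]:
    """Walk the character-level FSM over the postfix of a crossing token
    (e.g. "ning" when "morning" heals "mor"). Return the state reached,
    or None if some character has no transition (the postfix is illegal)."""
    state = start_state
    for char in postfix:
        transitions = char_transitions.get(state, {})
        if char not in transitions:
            return None
        state = transitions[char]
    return state
```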
Yes I think that's the right approach. There's some stuff to figure out in terms of design, but otherwise looks good.
I'll write unit tests next if you think having those separate functions is the right design
We're getting really close. There are a few design changes remaining, and mostly we should have comprehensive tests before merging.
I rebased your branch on main after a big refactor of the FSM interface. I will take a closer look this week.
Is this still something we want to work on?
Yes! I'm currently thinking about how we could integrate that into the logits processors since most integrations are going to use this :)
Sorry to prod but please don't lose sight of this! I think this is a very important change to make Outlines the most competitive structured generation system
I think it is time to revisit this as #966 is about to be merged and the custom sampling loop will be removed. We can still implement this via passing logit processors to downstream libraries. Effectively we will be adding this feature to every upstream library :)
@RobinPicard are you still interested in implementing this?
I can look at adapting it to the change made this weekend
That's great news!
Please let me know if you run into any issues or have any questions about OutlinesLogitsProcessor.
You probably want to branch from https://github.com/outlines-dev/outlines/pull/966 since it has fixes to the logits processors and a more detailed docstring.
To make sure I understand the wider context, the plan is to eventually remove SequenceGenerator and only use SequenceGeneratorAdapter @lapp0, right? If so, should we implement it for both of those or only the latter?
Indeed
I rebased on your branch and modified my initial commit @lapp0
Could you rebase on main now that #966 was merged?
This issue is causing problems for the PR. If we don't have an explanation/solution for it, we would have to modify the logic related to FSMLogitsProcessor._fsm_states. I don't know yet what we could replace it with, though, if the order of sequences is not maintained and the values of previous tokens can change.
@RobinPicard per my comment in the linked issue, it appears that transformers beam search submits an unused sequence group to logits processors during the final generation step. Is this still an issue if it only occurs on the last step and it's not actually part of generation?
Please let me know how I can help.
It's fine if it only happens at the final generation step, I simply added a try:... except KeyError:... block
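The guard is essentially of this shape (a simplified sketch keyed by a per-sequence hash, not the exact code of the PR):

```python
def get_sequence_state(fsm_states: dict[int, int], sequence_key: int, initial_state: int = 0) -> int:
    """Look up the FSM state tracked for a sequence."""
    try:
        return fsm_states[sequence_key]
    except KeyError:
        # Beam search may submit an unused sequence group at the final
        # generation step; in this sketch we fall back to the initial state
        # instead of raising.
        return initial_state
```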
Glad that it's not blocking, please let me know if you run into any other issues or have any questions!
I don't have more questions, but I would be interested in a review of the PR!
I don't get how step 2 would work. Are the unhealed tokens passed down to the original Guide just the user prompt?
Another design I'm considering for better separation of token healing and Guides is to create a dedicated class and have Guides that implement token healing inherit from it on top of GuideProtocol. Do you think that would be better?
Here's roughly what I'm thinking
In `SequenceGeneratorAdapter`:
- set the prompt as `alignment_guide.prompt_prefix`
- then call `model.generate`
- then left-truncate the generations

In `AlignmentGuide`:
- Constructor: `AlignmentGuide(prompt, tokenizer, child_guide=None)`
  - Let `self.start_generation_tokens` be the set of legal starting tokens for `child_guide` (or all tokens if no guide)
  - Tokenize `(prompt, token)` for each token in `start_generation_tokens` and determine the "longest common `self.prompt_prefix`" and the set of legal `self.prompt_suffix_ids`
  - Let `self.initial_state = AlignmentGuideState(legal_paths=legal_start_generation_tokens, child_guide_state=child_guide.initial_state)`
- For `get_next_instruction(state)`:
  - If `legal_paths` is None, alignment is complete. Defer to the `child_guide`.
  - If `legal_paths` is not None, `tokens = [path[0] for path in state.legal_paths]`
- For `get_next_state(state, token_id)`:
  - If legal paths: filter `state.legal_paths` such that it only includes those starting with `token_id`
    - If generating the final token in a path, pass the unaligned token to the child_guide, return `AlignmentGuideState(legal_paths=None, child_guide_state=new_child_guide_state)`
  - If `legal_paths` is None: use the child guide to update `child_guide_state`
This design requires constructing a Guide for each prompt; however, this is necessary because the behavior of the Guide varies prompt-to-prompt.
Details surrounding how to pre-process and manage state for AlignmentGuide may vary, this is just a rough outline.
Please let me know if this makes sense, or if I'm missing something.
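For concreteness, here is a minimal Python sketch of the design above. It is simplified on purpose: it only heals the last whitespace-separated piece of the prompt instead of computing the longest common prompt prefix, and the child-guide interface it calls (`initial_state`, `accepts_text`, `advance_text`, `get_next_instruction`, `get_next_state`) is assumed for the example rather than taken from the codebase:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AlignmentGuideState:
    # Candidate token paths completing the truncated prompt; None once
    # alignment is done and the child guide takes over.
    legal_paths: Optional[list[list[int]]]
    child_guide_state: Optional[object] = None


class AlignmentGuide:
    def __init__(self, prompt: str, vocabulary: dict[str, int], child_guide=None):
        self.child_guide = child_guide
        # Simplification: treat the last whitespace-separated piece of the
        # prompt as the healed region instead of the longest common prefix.
        self.removed_suffix = prompt.split(" ")[-1] if prompt else ""
        self.prompt_prefix = prompt[: len(prompt) - len(self.removed_suffix)]
        self.token_strings = {tid: s for s, tid in vocabulary.items()}

        legal_paths = []
        for token_str, token_id in vocabulary.items():
            if not token_str.startswith(self.removed_suffix):
                continue
            remainder = token_str[len(self.removed_suffix):]
            # Keep the token only if the child guide accepts the part that
            # extends past the prompt (everything is legal without a guide).
            if child_guide is None or child_guide.accepts_text(remainder):
                legal_paths.append([token_id])
        self.initial_state = AlignmentGuideState(
            legal_paths=legal_paths,
            child_guide_state=child_guide.initial_state if child_guide else None,
        )

    def get_next_instruction(self, state: AlignmentGuideState):
        if state.legal_paths is None:
            if self.child_guide is None:
                return None  # no constraint: any token is allowed
            # Alignment is complete: defer entirely to the child guide.
            return self.child_guide.get_next_instruction(state.child_guide_state)
        return [path[0] for path in state.legal_paths]

    def get_next_state(self, state: AlignmentGuideState, token_id: int):
        if state.legal_paths is None:
            if self.child_guide is None:
                return state
            new_child_state = self.child_guide.get_next_state(
                state.child_guide_state, token_id
            )
            return AlignmentGuideState(None, new_child_state)
        # Keep only the paths consistent with the sampled token.
        paths = [path[1:] for path in state.legal_paths if path[0] == token_id]
        if any(len(path) == 0 for path in paths):
            # Final token of a path: alignment ends here and the unaligned
            # remainder of the crossing token is passed to the child guide.
            remainder = self.token_strings[token_id][len(self.removed_suffix):]
            new_child_state = (
                self.child_guide.advance_text(state.child_guide_state, remainder)
                if self.child_guide is not None
                else None
            )
            return AlignmentGuideState(None, new_child_state)
        return AlignmentGuideState(paths, state.child_guide_state)
```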
I think I understand the general idea. The main issue I have concerns this point:
- If generating the final token in a path, pass the unaligned token to the child_guide
I don't get what the unaligned token means in this context.
Sorry I wasn't clear here.
For example, if the prompt is "hello wo" and we truncate it to "hello" for alignment, and the model generates the token " world" as the next output, "rld" is what needs to be passed to the child guide to continue processing.
This allows for alignment to be compatible with any guide with minimal changes.
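In code, the remainder handed to the child guide would be computed roughly like this (illustrative helper, not an existing function):

```python
def unaligned_remainder(healed_region: str, crossing_token: str) -> str:
    """E.g. healed_region " wo" (cut from "hello wo") and crossing token
    " world" give "rld", the part the child guide still has to process."""
    assert crossing_token.startswith(healed_region)
    return crossing_token[len(healed_region):]


assert unaligned_remainder(" wo", " world") == "rld"
```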
Ah I see, but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_tokens_map. Or should we walk through the states_to_tokens_map of the child guide character by character in the get_next_state method of AlignmentGuide when the generation reaches the crossing token?
Sorry for the delayed response.
but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_tokens_map.
If I'm understanding the problem correctly, to mitigate this we need to determine the "longest common prompt prefix". This will allow any legal token to be generated as a "pseudo-token".
Or should we walk through the states_to_tokens_map of the child guide character by character in the get_next_state method of AlignmentGuide when the generation reached the crossing token?
Can we precompute this when the guide is constructed?