Implement prompt/generation alignment


[updated 2024-06-28]

The aim of this PR is to implement prompt token alignment.

The idea is to modify the states_to_token_maps of the Guide to include the characters of some of the last tokens of the prompt, so that those tokens can be replaced by a different token containing the same characters plus the beginning of the generation (a crossing token).

To do so, when receiving the user's prompts (so after the OutlinesLogitsProcessor has already been initialized with its FSM), we copy the FSM as many times as there are prompts and apply prompt token alignment to each copy (the modification of the states_to_token_maps depends on the content of each prompt).

At the end of the process, we modify the generated sequences to remove the leading characters that correspond to the ends of the user prompts.
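
As a rough illustration of the preprocessing involved, here is a sketch of how crossing tokens could be found for a prompt. The function name, signature, and data structures are made up for illustration and don't reflect the actual implementation:

```python
# Hypothetical sketch: find the "crossing tokens" for a prompt, i.e.
# vocabulary tokens that start with a tail of the prompt and extend
# past its end. Names are illustrative, not the actual Outlines API.

def find_crossing_tokens(prompt: str, vocabulary: dict[str, int]) -> dict[int, list[int]]:
    """Map each candidate boundary (an index into the prompt) to the
    token ids that match the prompt's tail from that index and keep
    going past the end of the prompt."""
    crossing: dict[int, list[int]] = {}
    for token, token_id in vocabulary.items():
        # Only split points where the token could overlap the prompt end
        for i in range(max(0, len(prompt) - len(token) + 1), len(prompt)):
            tail = prompt[i:]
            if token.startswith(tail) and len(token) > len(tail):
                crossing.setdefault(i, []).append(token_id)
    return crossing
```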

RobinPicard avatar Jan 11 '24 23:01 RobinPicard

This is not intended to be merged; I was just wondering whether you think this is a promising direction to look into.

I think this is the right general direction.

  • The case in which the text after the "boundary" of a token matching the end of the prompt does not exist in the vocabulary by itself is not covered

Could you illustrate this? I had a PR opened (can't find it right now) where I iterated once over the vocabulary to find the overlapping tokens.

rlouf avatar Jan 27 '24 09:01 rlouf

Could you illustrate this? I had a PR opened (can't find it right now) where I iterated once over the vocabulary to find the overlapping tokens.

Making up a fake example. My prompt is "Good mor". Let's say there's a token for "mor" and it's the last one of the prompt. We would want token alignment to replace "mor" with "morning". However, if the token "ning" by itself does not exist, then there's nothing in the states_to_token_maps that corresponds to it: at this point, the character-based FSM that would allow generating "ning" has already been turned into a token-based mapping.
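
A toy illustration of the issue, with a made-up vocabulary:

```python
# Made-up vocabulary: "morning" exists but "ning" does not.
vocabulary = {"Good": 0, " mor": 1, "mor": 2, "morning": 3}

prompt_tail = "mor"
crossing_token = "morning"
suffix = crossing_token[len(prompt_tail):]  # "ning"

# The token-based states_to_token_maps only has transitions labeled by
# token ids; since "ning" was never a token, no transition accounts for
# the extra characters once "morning" has been generated.
assert suffix not in vocabulary
```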

I was looking at creating states_to_token_maps only after the call is made (and the FSM is updated) but that would add too much overhead.

I was then thinking that a solution could be to create, at initialization, a mapping that contains information about both characters and tokens (so we would have some states, with no tokens leading to them, that would be used for token alignment).

RobinPicard avatar Jan 27 '24 10:01 RobinPicard

How about looping over the entire vocabulary and storing the tokens that accept mor as a prefix? Then, in the unconstrained case, the first state of the FSM would have transitions to the overlapping tokens only.

Haven't taken the time to think about the constrained case yet.
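
For the unconstrained case, a minimal sketch of that scan (names are illustrative, not the actual implementation):

```python
# Scan the vocabulary once for tokens extending the prompt tail; in the
# unconstrained case, the FSM's first state would only have transitions
# to these overlapping tokens.

def overlapping_tokens(prompt_tail: str, vocabulary: dict[str, int]) -> dict[int, str]:
    """Return {token_id: suffix} for tokens that accept prompt_tail as
    a prefix, where suffix is the part generated past the prompt."""
    return {
        token_id: token[len(prompt_tail):]
        for token, token_id in vocabulary.items()
        if token.startswith(prompt_tail) and len(token) > len(prompt_tail)
    }
```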

rlouf avatar Jan 27 '24 10:01 rlouf

I had not realized that I could walk the states_to_token_maps character by character for the postfix part of the crossing tokens in the constrained case. I think it works with almost no additional overhead like that. Let me know if you think it's fine and I'll update the tests afterward
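
Roughly, the character-by-character walk would look like this (a hedged sketch; the transition structure is illustrative and may differ from the actual data structures):

```python
# Advance a character-level FSM over the postfix of a crossing token to
# find the state constrained generation should resume from.
# `transitions` maps state -> {character: next_state}.

def walk_postfix(transitions: dict[int, dict[str, int]],
                 start_state: int, postfix: str) -> int | None:
    """Return the state reached after consuming `postfix`, or None if
    the postfix is rejected (the crossing token is then illegal)."""
    state = start_state
    for char in postfix:
        next_states = transitions.get(state, {})
        if char not in next_states:
            return None
        state = next_states[char]
    return state
```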

RobinPicard avatar Jan 30 '24 09:01 RobinPicard

Yes I think that's the right approach. There's some stuff to figure out in terms of design, but otherwise looks good.

rlouf avatar Feb 10 '24 21:02 rlouf

I'll write unit tests next if you think having those separate functions is the right design

RobinPicard avatar Feb 17 '24 12:02 RobinPicard

We're getting really close. There are a few design changes remaining, and mostly we should have comprehensive tests before merging.

rlouf avatar Mar 01 '24 12:03 rlouf

I rebased your branch on main after a big refactor of the FSM interface. I will take a closer look this week.

rlouf avatar Mar 11 '24 21:03 rlouf

Is this still something we want to work on?

RobinPicard avatar Apr 12 '24 08:04 RobinPicard

Yes! I'm currently thinking about how we could integrate that into the logits processors since most integrations are going to use this :)

rlouf avatar Apr 12 '24 08:04 rlouf

Sorry to prod but please don't lose sight of this! I think this is a very important change to make Outlines the most competitive structured generation system

shawnz avatar May 30 '24 16:05 shawnz

I think it is time to revisit this as #966 is about to be merged and the custom sampling loop will be removed. We can still implement this via passing logit processors to downstream libraries. Effectively we will be adding this feature to every upstream library :)

@RobinPicard are you still interested in implementing this?

rlouf avatar Jun 19 '24 09:06 rlouf

@RobinPicard are you still interested in implementing this?

I can look at adapting it to the change made this weekend

RobinPicard avatar Jun 20 '24 05:06 RobinPicard

@RobinPicard are you still interested in implementing this?

I can look at adapting it to the change made this weekend

That's great news!

Please let me know if you run into any issues or have any questions about OutlinesLogitsProcessor.

You probably want to branch from https://github.com/outlines-dev/outlines/pull/966 since it has fixes to the logits processors and a more detailed docstring.

lapp0 avatar Jun 20 '24 10:06 lapp0

To make sure I understand the wider context, the plan is to eventually remove SequenceGenerator and only use SequenceGeneratorAdapter @lapp0, right? If so, should we implement it for both of those or only the latter?

RobinPicard avatar Jun 25 '24 09:06 RobinPicard

Indeed

rlouf avatar Jun 27 '24 00:06 rlouf

I rebased on your branch and modified my initial commit @lapp0

RobinPicard avatar Jun 28 '24 00:06 RobinPicard

Could you rebase on main now that #966 was merged?

rlouf avatar Jul 16 '24 13:07 rlouf

This issue is causing problems for the PR. If we don't have an explanation/solution for it, we would have to modify the logic related to FSMLogitsProcessor._fsm_states. I don't know yet what we could replace it with, though, if the order of sequences is not maintained and the values of previous tokens can change.

RobinPicard avatar Jul 19 '24 13:07 RobinPicard

@RobinPicard per my comment in the linked issue, it appears that transformers beam search submits an unused sequence group to logits processors during the final generation step. Is this still an issue if it only occurs on the last step and it's not actually part of generation?

Please let me know how I can help.

lapp0 avatar Jul 19 '24 15:07 lapp0

It's fine if it only happens at the final generation step; I simply added a try: ... except KeyError: ... block.
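
Something along these lines (a hypothetical shape of the guard; the real attribute layout may differ):

```python
def lookup_fsm_state(fsm_states: dict[int, int], sequence_key: int) -> int | None:
    """Tolerate the unused sequence group that transformers beam search
    submits on the final generation step: if the sequence was never seen
    during generation, return None and leave its logits untouched."""
    try:
        return fsm_states[sequence_key]
    except KeyError:
        return None
```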

RobinPicard avatar Jul 19 '24 16:07 RobinPicard

Glad that it's not blocking, please let me know if you run into any other issues or have any questions!

lapp0 avatar Jul 19 '24 17:07 lapp0

I don't have more questions; I would be interested in a review of the PR though!

RobinPicard avatar Jul 19 '24 19:07 RobinPicard

I don't get how step 2 would work. Are the unhealed tokens passed down to the original Guide just the user prompt?

Another design I'm considering for better separation of token healing and Guides is to create a dedicated class and have Guides that implement token healing inherit from it on top of GuideProtocol. Do you think that would be better?

RobinPicard avatar Jul 23 '24 06:07 RobinPicard

I don't get how step 2 would work. Are the unhealed tokens passed down to the original Guide just the user prompt?

Another design I'm considering for better separation of token healing and Guides is to create a dedicated class and have Guides that implement token healing inherit from it on top of GuideProtocol. Do you think that would be better?

Here's roughly what I'm thinking

In SequenceGeneratorAdapter

  • set the prompt as alignment_guide.prompt_prefix
  • then call model.generate,
  • then left-truncate generations

In AlignmentGuide

  • Constructor: AlignmentGuide(prompt, tokenizer, child_guide=None)
    • Let self.start_generation_tokens be the set of legal starting tokens for child_guide (or all tokens if no guide)
    • Tokenize prompt + token for each token in start_generation_tokens and determine the longest common self.prompt_prefix and the set of legal self.prompt_suffix_ids
    • self.initial_state = AlignmentGuideState(legal_paths=legal_start_generation_tokens, child_guide_state=child_guide.initial_state)
  • For get_next_instruction(state)
    • If legal_paths is None, alignment is complete; defer to the child_guide.
    • If legal_paths is not None, tokens = [path[0] for path in state.legal_paths]
  • For get_next_state(state, token_id)
    • If legal_paths is not None: filter state.legal_paths so that it only includes paths starting with token_id
      • If generating the final token in a path, pass the unaligned token to the child_guide and return AlignmentGuideState(legal_paths=None, child_guide_state=new_child_guide_state)
    • If legal_paths is None: use the child guide to update child_guide_state

This design requires constructing a Guide for each prompt; however, this is necessary because the behavior of the Guide varies prompt-to-prompt.

Details surrounding how to pre-process and manage state for AlignmentGuide may vary; this is just a rough outline.

Please let me know if this makes sense, or if I'm missing something.
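
To make the state machine above concrete, here is a rough, non-authoritative skeleton. The preprocessing (computing prompt_prefix and the legal token paths) is elided, and all names mirror the sketch above rather than the real Outlines interface:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AlignmentGuideState:
    # Remaining legal token-id paths during alignment; None once
    # alignment is complete and the child guide has taken over.
    legal_paths: frozenset | None
    child_guide_state: object = None


class AlignmentGuide:
    def __init__(self, legal_paths, child_guide=None):
        # `legal_paths`: iterable of token-id tuples, precomputed from the
        # prompt and the child guide's legal starting tokens.
        self.child_guide = child_guide
        self.initial_state = AlignmentGuideState(
            legal_paths=frozenset(tuple(path) for path in legal_paths),
            child_guide_state=child_guide.initial_state if child_guide else None,
        )

    def get_next_instruction(self, state):
        if state.legal_paths is None:
            # Alignment is complete: defer entirely to the child guide.
            return self.child_guide.get_next_instruction(state.child_guide_state)
        # Only the first token of each remaining path is legal.
        return sorted({path[0] for path in state.legal_paths})

    def get_next_state(self, state, token_id):
        if state.legal_paths is None:
            new_child_state = self.child_guide.get_next_state(
                state.child_guide_state, token_id
            )
            return AlignmentGuideState(None, new_child_state)
        # Keep only the paths consistent with the chosen token, minus its head.
        remaining = {path[1:] for path in state.legal_paths if path[0] == token_id}
        if () in remaining:
            # Final token of a path: alignment is done. (Passing the token's
            # unaligned remainder to the child guide is elided; see the
            # discussion below.)
            return AlignmentGuideState(None, state.child_guide_state)
        return AlignmentGuideState(frozenset(remaining), state.child_guide_state)
```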

lapp0 avatar Jul 26 '24 18:07 lapp0

I think I understand the general idea. The main issue I have concerns

  • If generating the final token in a path, pass the unaligned token to the child_guide

I don't get what the unaligned token means in this context.

RobinPicard avatar Aug 02 '24 10:08 RobinPicard

I don't get what the unaligned token means in this context.

Sorry I wasn't clear here.

For example, if the prompt is "hello wo" and we truncate it to "hello" for alignment, and the model generates the token " world" as the next output, "rld" is what needs to be passed to the child guide to continue processing.

This allows for alignment to be compatible with any guide with minimal changes.
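
In code, for that example:

```python
# Concrete values from the example above.
prompt = "hello wo"
aligned_prefix = "hello"        # prompt truncated for alignment
generated_token = " world"      # model output crossing the boundary

overlap = prompt[len(aligned_prefix):]      # " wo"
remainder = generated_token[len(overlap):]  # "rld"
assert remainder == "rld"
# `remainder` is what gets passed to the child guide to continue processing.
```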

lapp0 avatar Aug 02 '24 14:08 lapp0

Ah I see, but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_tokens_map. Or should we walk through the states_to_tokens_map of the child guide character by character in the get_next_state method of AlignmentGuide when the generation reaches the crossing token?

RobinPicard avatar Aug 02 '24 14:08 RobinPicard

Sorry for the delayed response.

but I thought the problem is that "rld" may not exist as a token in the vocabulary so it would not be found in the states_to_tokens_map.

If I'm understanding the problem correctly, to mitigate this we need to determine the "longest common prompt prefix". This will allow any legal token to be generated as a "pseudo-token".

Or should we walk through the states_to_tokens_map of the child guide character by character in the get_next_state method of AlignmentGuide when the generation reaches the crossing token?

Can we precompute this when the guide is constructed?
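
For instance, the longest common prompt prefix could be precomputed at construction roughly like this (an illustrative sketch, not the actual implementation):

```python
# Given the tokenizations of prompt + token for every legal starting
# token, precompute the longest common token-id prefix (kept as the
# prompt) and the legal suffix paths (the "pseudo-tokens").

def align_paths(tokenizations: list[list[int]]) -> tuple[list[int], set[tuple[int, ...]]]:
    prefix = list(tokenizations[0])
    for ids in tokenizations[1:]:
        common = 0
        for a, b in zip(prefix, ids):
            if a != b:
                break
            common += 1
        prefix = prefix[:common]
    suffix_paths = {tuple(ids[len(prefix):]) for ids in tokenizations}
    return prefix, suffix_paths
```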

lapp0 avatar Aug 12 '24 19:08 lapp0