Using token sequences as stop criteria does not work in mlx_lm

Open · tidely opened this issue 5 months ago · 7 comments

The implementation of stop_criteria in mlx_lm.server is inherently flawed. Stop sequences are only matched when the newest generated tokens exactly match a stop sequence; generation does not stop if a stop sequence appears anywhere else within the tokens.

This only checks whether the newest tokens exactly match a stop sequence:

for stop_ids in stop_id_sequences:
    if len(tokens) >= len(stop_ids):
        if np.array_equal(tokens[-len(stop_ids) :], stop_ids):
            return StopCondition(stop_met=True, trim_length=len(stop_ids))

stopping_criteria only gets called once max_stop_id_sequence_len tokens have been buffered, where max_stop_id_sequence_len is the length of the longest stop sequence:

if len(stop_sequence_buffer) > max_stop_id_sequence_len:
    if REPLACEMENT_CHAR in _tokenizer.decode(token):
        continue
    stop_condition = stopping_criteria(
        tokens,
        stop_id_sequences,
        eos_token_id,
    )

Example where it breaks:

I have two stop sequences, one of length 4 and one of length 6. Once 6 tokens have been generated, stop_criteria is called. However, the tokens I have generated only match the stop sequence of length 4, and the match happens at the start of the new tokens, not at the end. Since stop_criteria only checks the end of the full token list, the sequence is not matched and generation does not stop.

stop sequence = [1, 2, 3, 4]
new tokens = [1, 2, 3, 4, 5, 6]

# stop_criteria gets called and checks the last len(stop_sequence) tokens of new tokens
[3, 4, 5, 6] != [1, 2, 3, 4]
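
For illustration, a minimal sketch of a check that scans for a stop sequence anywhere in the token list instead of only at the tail (the StopCondition namedtuple is assumed to keep its existing shape; this is a sketch, not the upstream implementation):

import numpy as np
from collections import namedtuple

StopCondition = namedtuple("StopCondition", ["stop_met", "trim_length"])

def stopping_criteria(tokens, stop_id_sequences, eos_token_id):
    # Stop immediately on EOS.
    if tokens and tokens[-1] == eos_token_id:
        return StopCondition(stop_met=True, trim_length=0)

    # Slide each stop sequence over the whole token list instead of
    # comparing it only against the very end.
    for stop_ids in stop_id_sequences:
        for start in range(len(tokens) - len(stop_ids) + 1):
            if np.array_equal(tokens[start : start + len(stop_ids)], stop_ids):
                # Trim the match and everything generated after it.
                return StopCondition(stop_met=True, trim_length=len(tokens) - start)

    return StopCondition(stop_met=False, trim_length=0)

With the example above, stopping_criteria([1, 2, 3, 4, 5, 6], [[1, 2, 3, 4]], eos_token_id=0) finds the match at position 0 and returns trim_length=6.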

tidely avatar Mar 03 '24 15:03 tidely

Additionally, the condition for checking the stop criteria is:

if len(stop_sequence_buffer) > max_stop_id_sequence_len:

It is not inclusive, meaning an extra token is always appended before the check runs, so no stop sequence can ever be matched.

tidely avatar Mar 04 '24 13:03 tidely

Just curious, how did you find the issue? I ran a few tests before and didn't see any extra tokens added.

mzbac avatar Mar 04 '24 13:03 mzbac

I'll try to sum it up:

max_stop_id_sequence_len is the length of the longest stop id sequence. Now let's assume the buffer is the same size as the longest stop sequence.

if len(stop_sequence_buffer) > max_stop_id_sequence_len:

Since this check is non-inclusive, the loop runs one extra time before reaching the body of the if statement, so the buffer ends up one token longer than the longest stop sequence.

Now take into account that the stop_criteria function only checks for perfect matches where the tail of tokens matches a stop sequence. It can no longer ever match, because the extra token was generated before stop_criteria was called.
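
To make the off-by-one concrete, a toy trace of the buffer check (the values are illustrative, not taken from the server code):

# Longest stop sequence has length 3, so max_stop_id_sequence_len == 3.
max_stop_id_sequence_len = 3
stop_sequence_buffer = []

for token in [1, 2, 3, 4]:
    stop_sequence_buffer.append(token)
    # Non-inclusive check: only fires once the buffer holds 4 tokens,
    # i.e. one token after a 3-token stop sequence could already have completed.
    if len(stop_sequence_buffer) > max_stop_id_sequence_len:
        print("check runs with buffer:", stop_sequence_buffer)
    # An inclusive check would already run at 3 tokens:
    # if len(stop_sequence_buffer) >= max_stop_id_sequence_len: ...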

tidely avatar Mar 04 '24 14:03 tidely

I see, ~it would only happen when the model starts with the stop word. Maybe that's why it wasn't picked up by my testing.~ This edge case is a bit difficult to pick up with manual testing, but it would be more obvious when the model starts with a stop word.

mzbac avatar Mar 04 '24 14:03 mzbac

Yeah, I think it was intended that it runs on every token generated, and we can throw away the buffer entirely. This would address most issues.

It still needs a special case for streaming, since we need to anticipate a stop word, or we might stream parts of it. I have a prototype for that check currently, but still doing some testing.
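
For reference, one way such an anticipation check could look (a sketch only, not the prototype mentioned above): report how many trailing tokens are a prefix of some stop sequence, so the server holds them back from the stream until they either complete a stop word or diverge from it.

def sequence_overlap(tokens, stop_id_sequences):
    """Length of the longest tail of `tokens` that is a proper prefix of
    any stop sequence, i.e. how many trailing tokens must be withheld
    from the stream because they may still complete a stop word."""
    max_overlap = 0
    for stop_ids in stop_id_sequences:
        # Compare proper prefixes of each stop sequence against the tail of tokens.
        for length in range(1, len(stop_ids)):
            if len(tokens) >= length and list(tokens[-length:]) == list(stop_ids[:length]):
                max_overlap = max(max_overlap, length)
    return max_overlap

# Example: with stop sequence [1, 2, 3, 4] and generated tokens ending in [..., 1, 2],
# sequence_overlap(...) returns 2, so the last two tokens are not streamed yet.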

tidely avatar Mar 04 '24 15:03 tidely

Yeah, you are right. The original implementation didn't have a buffer, so it ended up sending the stop word back to the client when streaming. The buffer was introduced to solve that issue, but it seems like the implementation wasn't well thought out.

mzbac avatar Mar 04 '24 15:03 mzbac

Probably. I have kind of the same problem: the generate function outputs a single character per token, so re-encoding the output does not give back the original token ids. Here is some pseudocode for the problem:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

temperature = 0.2
max_tokens = 512  # illustrative
prompt = "Extract in JSON format: "
generated_tokens = []

# generate() stands in for a streaming generation loop (pseudocode);
# each yielded `token` is a decoded string fragment.
for token in generate(model, tokenizer, prompt, temperature, max_tokens, verbose=True):
    # Re-encoding the fragment gives character-level ids, not the original token ids.
    generated_tokens.append(tokenizer.encode(token))

text = "Here is the extracted text"
# What the loop above collects (one character-level id per fragment):
generated_tokens = [[220], [39], [68], [81], [68], [220], [72], [82], [220], [83], [71], [68], [220], [68], [87], [83], [81], [64], [66], [83], [68], [67], [220], [83], [68], [87], [83]]

# What the ids for the same text should look like:
how_it_should_be = tokenizer.encode(text)
output = [8586, 374, 279, 28532, 1495]

Is there an easier way / example to retrieve the actual token ids per generated token, such as 8586 or 374, and use those as stop criteria? Ideally, multiple tokens as a stop criterion would be best, so synthetic JSON generation can be stopped just as it finishes.
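
As a rough illustration of the token-id-level check being asked about (token_stream, model, and prompt below are hypothetical placeholders, not an mlx_lm API): collect the integer ids per step and compare the tail against a multi-token stop sequence.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Multi-token stop criterion, e.g. the ids of the text that should end generation.
stop_ids = tokenizer.encode("\n}\n", add_special_tokens=False)

generated_ids = []
stop_met = False
for token_id in token_stream(model, prompt):  # hypothetical generator yielding one int id per step
    generated_ids.append(token_id)
    # Compare the newest ids against the stop sequence.
    if generated_ids[-len(stop_ids):] == stop_ids:
        stop_met = True
        break

# Drop the stop sequence itself before decoding.
if stop_met:
    generated_ids = generated_ids[: -len(stop_ids)]
text = tokenizer.decode(generated_ids)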

s-smits avatar Apr 21 '24 18:04 s-smits