mlx-examples

Predict stop sequence matches during streaming

Status: Open. tidely opened this issue 5 months ago • 5 comments

When streaming with mlx_lm/server.py, we should predict potential stop sequence matches: if the most recent tokens could be the beginning of a stop sequence, keep generating until a match is either confirmed or ruled out. This prevents the server from sending part of a stop sequence to the client before it detects the full match.

Fixes #524

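My implementation adds a new function called "sequence_overlap", which returns how much the end of sequence 1 overlaps the start of sequence 2, i.e. the length of the longest suffix of s1 that is also a prefix of s2. It checks larger overlaps first and returns the overlap length as an integer (0 when there is none).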
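A few hand-checked cases make the definition concrete. The function is repeated here in compact form so the snippet runs on its own. Because it tries the largest overlap first, it reports the longest suffix/prefix match. Both arguments should be the same sequence type (for example, both lists), since in Python a tuple slice never compares equal to a list slice.

```python
from typing import Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    # Longest k such that the last k items of s1 equal the first k items of s2
    for index in range(min(len(s1), len(s2)), 0, -1):
        if s1[-index:] == s2[:index]:
            return index
    return 0


stop = [27, 28, 29]
print(sequence_overlap([5, 27], stop))          # 1: trailing 27 may start the stop sequence
print(sequence_overlap([5, 27, 28], stop))      # 2
print(sequence_overlap([5, 27, 28, 29], stop))  # 3: full match
print(sequence_overlap([5, 28], stop))          # 0: 28 alone is not a prefix of the stop sequence
print(sequence_overlap([1, 1], [1, 1, 2]))      # 2, not 1: larger overlaps are tried first
```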

The server checks the generated tokens for an overlap with every stop sequence and, while any overlap exists, keeps generating tokens instead of sending them to the client:

# Inside the token loop: hold back tokens that might be the start of a stop sequence
if any(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences):
    continue
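The same hold-back idea can be packaged as a generator that handles several stop sequences at once. This is a minimal standalone sketch, assuming tokens arrive one at a time from an iterator; the name `stream_until_stop` is hypothetical and not part of mlx_lm. It yields only chunks that are safe to send and, when a stop sequence completes, still emits any held-back tokens that came before it.

```python
from typing import Iterable, Iterator, List, Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    # Length of the longest suffix of s1 that is also a prefix of s2
    for index in range(min(len(s1), len(s2)), 0, -1):
        if s1[-index:] == s2[:index]:
            return index
    return 0


def stream_until_stop(
    token_iter: Iterable[int], stop_id_sequences: List[List[int]]
) -> Iterator[List[int]]:
    """Hypothetical helper: yield token chunks that can safely be sent to a client."""
    tokens: List[int] = []   # everything generated so far
    pending: List[int] = []  # generated but not yet sent
    for token in token_iter:
        tokens.append(token)
        pending.append(token)

        # 1. A complete stop sequence: send what precedes it, drop the stop tokens, stop.
        for stop in stop_id_sequences:
            if len(tokens) >= len(stop) and tokens[-len(stop):] == stop:
                keep = pending[: max(0, len(pending) - len(stop))]
                if keep:
                    yield keep
                return

        # 2. The tail of `tokens` may be the start of some stop sequence: hold back.
        if any(sequence_overlap(tokens, stop) for stop in stop_id_sequences):
            continue

        # 3. No possible match: flush everything pending.
        yield pending
        pending = []  # rebind rather than clear(), so the yielded chunk is not mutated

    # Generation ended while holding back a partial match: those tokens are real output.
    if pending:
        yield pending
```

For example, `list(stream_until_stop([1, 2, 3, 9, 4, 5, 6], [[3, 4], [5, 6, 7]]))` gives `[[1], [2], [3, 9], [4], [5, 6]]`: the 3 is held until the 9 rules out `[3, 4]`, and the trailing `[5, 6]` is flushed when the iterator ends. With `list(stream_until_stop([1, 1, 1, 2, 8], [[1, 1, 2]]))` the result is `[[1]]`, so the first 1, which was held back but is not part of the stop sequence, is not lost.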

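The sequence_overlap implementation can be tested with the following standalone simulation, which streams the tokens 0 to 49 and uses [27, 28, 29] as the stop sequence: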

from typing import Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    """
    Check how much the end of s1 overlaps the start of s2.

    Args:
        s1 (Sequence): The first sequence, whose end is checked.
        s2 (Sequence): The second sequence, whose beginning is checked.

    Returns:
        int: The length of the longest suffix of s1 that is also a prefix of s2 (0 if none).
    """
    # Count down from the length of the shorter sequence, so larger overlaps are checked first
    for index in range(min(len(s1), len(s2)), 0, -1):
        # Do the last `index` items of s1 equal the first `index` items of s2?
        if s1[-index:] == s2[:index]:
            return index
    return 0


stop_sequence = [27, 28, 29]

tokens = []
new_tokens = []

for token in range(50):
    tokens.append(token)
    new_tokens.append(token)

    # This must be the first check, since it has to run on every new token
    if len(tokens) >= len(stop_sequence) and tokens[-len(stop_sequence):] == stop_sequence:
        print("Contains stop sequence:", new_tokens)
        tokens = tokens[:len(tokens) - len(stop_sequence)]
        # Held-back tokens that came before the stop sequence are still valid output
        leftover = new_tokens[:max(0, len(new_tokens) - len(stop_sequence))]
        if leftover:
            print("Processing new tokens:", leftover)
        new_tokens.clear()
        break

    # Hold back new tokens while the end of `tokens` could still be the start of the stop sequence
    if sequence_overlap(tokens, stop_sequence):
        print("Found a possible start to a stop sequence:", new_tokens)
        continue

    # Process new tokens
    print("Processing new tokens:", new_tokens)
    new_tokens.clear()

# If generation ends while a partial stop sequence is being held back,
# those tokens were never processed, so flush them here
if new_tokens:
    print("Processing leftover new tokens:", new_tokens)
    new_tokens.clear()

print("Full sequence:", tokens)

tidely, Mar 06 '24 16:03

Oh, much better, thank you 😄

awni, Mar 06 '24 16:03

What about adding your test (modified for unittest) as a test case to a new test file test_server.py in the tests directory: https://github.com/ml-explore/mlx-examples/tree/main/llms/tests ?

awni, Mar 14 '24 04:03
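A minimal sketch of what such a test file could look like. It copies `sequence_overlap` inline so the file runs on its own; in a real tests/test_server.py the function would be imported from the server module instead, and that import path is left out here because it depends on where the function ends up. The class and test names are hypothetical. The tests can be run with `python -m unittest`.

```python
import unittest
from typing import Sequence


def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    # Copy of the function from the PR description, so this file is self-contained
    for index in range(min(len(s1), len(s2)), 0, -1):
        if s1[-index:] == s2[:index]:
            return index
    return 0


class TestSequenceOverlap(unittest.TestCase):
    def test_no_overlap(self):
        self.assertEqual(sequence_overlap([1, 2, 3], [4, 5, 6]), 0)

    def test_partial_overlap(self):
        # The end of s1 ([27, 28]) matches the start of s2
        self.assertEqual(sequence_overlap([5, 27, 28], [27, 28, 29]), 2)

    def test_full_overlap(self):
        self.assertEqual(sequence_overlap([5, 27, 28, 29], [27, 28, 29]), 3)

    def test_longest_overlap_wins(self):
        # Both 1 and 2 are valid overlaps; the larger one is returned
        self.assertEqual(sequence_overlap([1, 1], [1, 1, 2]), 2)

    def test_only_suffix_of_s1_counts(self):
        # 27 appears in s1 but not at its end, so it is not an overlap
        self.assertEqual(sequence_overlap([27, 5], [27, 28, 29]), 0)

    def test_empty_inputs(self):
        self.assertEqual(sequence_overlap([], [1, 2]), 0)
        self.assertEqual(sequence_overlap([1, 2], []), 0)
```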

Yeah, I'll look into writing a unittest.

tidely, Mar 14 '24 08:03

Hi! Is this ready to be reviewed again?

awni, Mar 20 '24 04:03

@awni yes

tidely, Mar 20 '24 18:03