mlx-examples
Predict stop sequence matches during streaming
When streaming with mlx_lm/server.py we should predict potential stop sequence matches and keep generating tokens until we know there is no match. This prevents the server from sending part of a stop sequence to the client before the full match is found.
Fixes #524
My implementation adds a new function, sequence_overlap, which checks how much the end of sequence 1 overlaps with the start of sequence 2. It checks for larger overlaps first and returns the overlap length as an integer.
The server checks for overlaps and generates more tokens until it knows it is safe to send them:
if any(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences):
    continue
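To illustrate the gate in isolation with more than one stop sequence, here is a minimal standalone sketch (the stop_id_sequences values are made up for the example; sequence_overlap matches the implementation shown below):

```python
from typing import Sequence

def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    # Length of overlap between the end of s1 and the start of s2,
    # checking larger overlaps first
    for index in range(min(len(s1), len(s2)), 0, -1):
        if s1[-index:] == s2[:index]:
            return index
    return 0

# Hypothetical stop sequences and generated tokens
stop_id_sequences = [[5, 6, 7], [9, 9]]
tokens = [1, 2, 3, 5, 6]  # ends with [5, 6], a prefix of [5, 6, 7]

# Hold back the tokens while any stop sequence could still complete
held = any(sequence_overlap(tokens, seq) for seq in stop_id_sequences)
print(held)  # True: [5, 6] could still grow into [5, 6, 7]
```

As soon as a token arrives that rules out every stop sequence, the check returns False and the held-back tokens can be streamed to the client.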
The sequence_overlap implementation can be tested with this example:
from typing import Sequence

def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    """
    Check how much overlap two sequences have.
    Only the end of s1 is checked against the start of s2.

    Args:
        s1 (Sequence): The first sequence, whose end is checked
        s2 (Sequence): The second sequence, whose beginning is checked

    Returns:
        int: The length of the overlap between s1 and s2
    """
    # Count down from the length of the shorter sequence -> checks larger overlaps first
    for index in range(min(len(s1), len(s2)), 0, -1):
        # Check whether the last `index` items of s1 equal the first `index` items of s2
        if s1[-index:] == s2[:index]:
            return index
    return 0
stop_sequence = [27, 28, 29]
tokens = []
new_tokens = []
for token in range(50):
    tokens.append(token)
    new_tokens.append(token)
    # This should always be the first check, since it must run on every token
    if len(tokens) >= len(stop_sequence) and tokens[-len(stop_sequence):] == stop_sequence:
        print("Contains stop sequence:", new_tokens)
        tokens = tokens[: len(tokens) - len(stop_sequence)]
        new_tokens.clear()
        break
    # Hold back tokens until we know they cannot complete a stop sequence
    if sequence_overlap(tokens, stop_sequence):
        print("Found a possible start of a stop sequence:", new_tokens)
        continue
    # Process new tokens
    print("Processing new tokens:", new_tokens)
    new_tokens.clear()
# If generation ends with the start of a stop sequence, there may be
# held-back tokens left over, since the loop kept hitting continue
if new_tokens:
    print("Processing leftover new tokens:", new_tokens)
    new_tokens.clear()
print("Full sequence:", tokens)
Oh much better, thank you 😄
What about adding your test (modified for unittest) as a test case in a new test file test_server.py
in the tests directory: https://github.com/ml-explore/mlx-examples/tree/main/llms/tests ?
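A sketch of what such a unittest might look like; the class name and test cases are assumptions for illustration, not necessarily what was merged, and sequence_overlap is restated so the snippet is self-contained rather than imported from the server module:

```python
import unittest
from typing import Sequence

def sequence_overlap(s1: Sequence, s2: Sequence) -> int:
    # Same helper as in the PR: length of overlap between the end of s1
    # and the start of s2, checking larger overlaps first
    for index in range(min(len(s1), len(s2)), 0, -1):
        if s1[-index:] == s2[:index]:
            return index
    return 0

class TestSequenceOverlap(unittest.TestCase):
    def test_full_overlap(self):
        self.assertEqual(sequence_overlap([27, 28, 29], [27, 28, 29]), 3)

    def test_partial_overlap(self):
        self.assertEqual(sequence_overlap([1, 2, 27, 28], [27, 28, 29]), 2)

    def test_no_overlap(self):
        self.assertEqual(sequence_overlap([1, 2, 3], [27, 28, 29]), 0)

    def test_mid_sequence_match_ignored(self):
        # Only the end of s1 against the start of s2 counts,
        # so a stop sequence buried mid-sequence returns 0
        self.assertEqual(sequence_overlap([27, 28, 29, 1], [27, 28, 29]), 0)

if __name__ == "__main__":
    unittest.main()
```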
Yeah, I'll have a look into writing a unittest
Hi! Is this ready to be reviewed again?
@awni yes