Weirdness with tokenization in Phi-3
Server:
toolio_server --model=mlx-community/Phi-3-mini-128k-instruct-4bit
Client:
toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow?'
You can run the above any number of times, but as soon as you run a version that tries to use a prior prompt cache:
toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow? Where have I heard that before?'
It blows up. Server exception tail:
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/cli/server.py", line 271, in post_v1_chat_completions_impl
for result in app.state.model.completion(
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 296, in completion
logits, cache = self._evaluate_prompt(
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 92, in _evaluate_prompt
logits = self.model(mx.array(tokens)[None], cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 202, in __call__
out = self.model(inputs, cache)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 184, in __call__
h = layer(h, mask, c)
^^^^^^^^^^^^^^^^^
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 148, in __call__
r = self.self_attn(self.input_layernorm(x), mask, cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 110, in __call__
output = mx.fast.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Shapes (1,32,9,24) and (9,9) cannot be broadcast.
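As an aside, the shapes in that error are the attention scores (batch, heads, new tokens, total key positions) colliding with an additive causal mask built only over the new tokens. A minimal illustration of the broadcast failure (not Toolio or mlx_lm code, just MLX array shapes):
import mlx.core as mx
scores = mx.zeros((1, 32, 9, 24))  # (batch, heads, 9 new tokens, 24 total key positions)
good_mask = mx.zeros((9, 24))      # covers the cached key positions too
bad_mask = mx.zeros((9, 9))        # built from the 9 new tokens only
print((scores + good_mask).shape)  # (1, 32, 9, 24)
print((scores + bad_mask).shape)   # ValueError: shapes cannot be broadcast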
Here's schema_helper.py modified to add a trace:
def _evaluate_prompt(
    self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
):
    if prior_prompt:
        i = 0
        for i, t in enumerate(prior_prompt):
            # Need to leave at least one token to evaluate because we don't
            # save the past logits.
            if i >= len(prompt) - 1 or prompt[i] != t:
                break
        cache = prior_cache
        for layer_cache in cache:
            layer_cache.reuse(len(prompt), i)
        tokens = prompt[i:]
        print('CACHED', tokens, prompt)
    else:
        cache = ReusableKVCache.for_model(self.model)
        tokens = prompt
        print('UNCACHED', tokens)
    logits = self.model(mx.array(tokens)[None], cache)
    return logits, cache
First run of the shorter prompt displays:
UNCACHED [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
Notice the repeated 32007 right away: that's the Phi-3 '<|end|>' token, and it probably shouldn't appear twice. An identical run again:
CACHED [32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
That's the expected logic: nothing but that end token is left to evaluate post-cache. Now the longer prompt:
CACHED [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
Again the end token is doubled. At this point I don't know whether this tokenizer oddness is what leads to the shape error, but it's a starting point for investigation.
Quick look at the Phi-3 tokenizer:
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained(
    # Should be the same tokenizer as microsoft/Phi-3-mini-128k-instruct
    'mlx-community/Phi-3-mini-128k-instruct-4bit'
)

S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=False)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))

S = 'Hello<|end|>'
ids = tokenizer.encode(S, add_special_tokens=True)
print(ids)
S_decode = tokenizer.decode(ids)
print(repr(S_decode))
Output:
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
[15043, 32007]
'Hello<|end|>'
[15043, 32007]
'Hello<|end|>'
The 'Special tokens' warning comes up as soon as you load the tokenizer, and has nothing to do with the add_special_tokens=True|False setting later on.
repr of tokenizer:
LlamaTokenizerFast(name_or_path='mlx-community/Phi-3-mini-128k-instruct-4bit', vocab_size=32000, model_max_length=131072, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '<|end|>', 'unk_token': '<unk>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={
0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
2: AddedToken("</s>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=False),
32000: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
32001: AddedToken("<|assistant|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32002: AddedToken("<|placeholder1|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32003: AddedToken("<|placeholder2|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32004: AddedToken("<|placeholder3|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32005: AddedToken("<|placeholder4|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32006: AddedToken("<|system|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32007: AddedToken("<|end|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32008: AddedToken("<|placeholder5|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32009: AddedToken("<|placeholder6|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
32010: AddedToken("<|user|>", rstrip=True, lstrip=False, single_word=False, normalized=False, special=True),
}
So yes, Phi-3 uses the Llama tokenizer. Notice that the special tokens are added with rstrip=True, i.e. with whitespace normalization.
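To see what rstrip=True means in practice, here's a quick check (a sketch I haven't run against this exact tokenizer; it reuses the tokenizer loaded above): whitespace to the right of '<|end|>' should get stripped during encoding.
S_ws = 'Hello<|end|>   '
ids_ws = tokenizer.encode(S_ws, add_special_tokens=False)
print(ids_ws)                          # expected to match the earlier result: [15043, 32007]
print(repr(tokenizer.decode(ids_ws)))  # expected 'Hello<|end|>', trailing spaces stripped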
A trimmed down repro case:
import mlx.core as mx
from toolio.schema_helper import Model, ReusableKVCache

m = Model()
m.load('mlx-community/Phi-3-mini-128k-instruct-4bit')

# First prompt, evaluated with a fresh cache
cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]
logits = m.model(mx.array(tokens1)[None], cache)

# Second, longer prompt: reuse the cache and evaluate only the post-cache suffix
tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)
logits = m.model(mx.array(tokens2_postcache)[None], cache)
Result: ValueError: Shapes (1,32,9,32) and (9,9) cannot be broadcast.
Note: just blindly replacing each doubled 32007, 32007 with a single 32007 merely tweaked the error: ValueError: Shapes (1,32,8,30) and (8,8) cannot be broadcast.
cache = ReusableKVCache.for_model(m.model)
tokens1 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007]
logits = m.model(mx.array(tokens1)[None], cache)

tokens2 = [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007]
tokens2_postcache = [6804, 505, 306, 6091, 393, 1434, 29973, 32007]
for layer_cache in cache:
    layer_cache.reuse(len(tokens2), len(tokens2)-1)
logits = m.model(mx.array(tokens2_postcache)[None], cache)
For now I've worked around this by disabling prompt caching by default. I'll leave the ticket open, though, because it would be nice to work out a proper fix.
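Whatever the root cause turns out to be, the immediate failure is that the causal mask covers only the new tokens while the scores also span the cached key positions. For reference, a mask built over the full key length would broadcast cleanly; this is just a sketch of the idea, not mlx_lm's actual mask helper:
import mlx.core as mx

def causal_mask_with_offset(num_new, offset, dtype=mx.float32):
    # Each new query may attend to all `offset` cached positions plus the
    # causal portion of the new positions: shape (num_new, offset + num_new)
    total = offset + num_new
    rows = offset + mx.arange(num_new)[:, None]  # absolute position of each query
    cols = mx.arange(total)[None, :]             # absolute position of each key
    neg_inf = mx.full((num_new, total), float('-inf'), dtype=dtype)
    return mx.where(cols <= rows, mx.zeros((num_new, total), dtype=dtype), neg_inf)

print(causal_mask_with_offset(9, 15).shape)  # (9, 24), which broadcasts against (1, 32, 9, 24)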