Metal thread safety
Proposed changes
These changes are an attempt to improve thread safety for the Metal backend. This is related to #2067. Please let me know what you think.
Checklist
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
This looks interesting, but in general I'm not convinced we need to go the route of fine-grained locking. It might work just as well, maybe even better, and be a lot cleaner and faster to take a lock where we do the task submission in the main eval loop.
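To make the coarse-grained alternative concrete, here is a minimal Python sketch. It is not MLX code: `SharedCommandQueue`, `eval_graph`, `graphs_are_contiguous` and `run_concurrent_evals` are hypothetical names. It takes one lock around a whole submission loop. Threads still evaluate concurrently, but each graph's tasks reach the shared queue as one ordered, uninterrupted run, and the lock is acquired once per eval rather than once per primitive.

```python
import threading


class SharedCommandQueue:
    """Hypothetical stand-in for a shared GPU command queue (not MLX's API)."""

    def __init__(self):
        self.submitted = []                   # (graph_id, task_index) in submission order
        self._submit_lock = threading.Lock()  # one coarse lock for the whole eval loop

    def eval_graph(self, graph_id, num_tasks):
        # Coarse-grained locking: take the lock once around the entire
        # submission loop instead of locking each individual primitive.
        with self._submit_lock:
            for task_index in range(num_tasks):
                self.submitted.append((graph_id, task_index))


def graphs_are_contiguous(submitted):
    """True if every graph's tasks were submitted as one ordered, uninterrupted run."""
    seen = set()
    prev_graph = None
    expected_index = 0
    for graph_id, task_index in submitted:
        if graph_id != prev_graph:
            if graph_id in seen:
                return False  # this graph was interleaved with another one
            seen.add(graph_id)
            prev_graph, expected_index = graph_id, 0
        if task_index != expected_index:
            return False
        expected_index += 1
    return True


def run_concurrent_evals(num_graphs=8, tasks_per_graph=1_000):
    """Evaluate several hypothetical graphs from separate threads."""
    queue = SharedCommandQueue()
    threads = [
        threading.Thread(target=queue.eval_graph, args=(g, tasks_per_graph))
        for g in range(num_graphs)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return queue.submitted
```

The trade-off is that one thread's submission blocks every other thread for its duration. Whether that costs more than many small fine-grained locks is exactly what the benchmarks further down are meant to answer.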
A couple of higher-level comments:
- It would be good to add tests that show the problem and that the solution works
- We'll need to do some performance benchmarking to be sure there are no regressions from any solution we come up with
Thank you! Let me see what I can do along those lines.
The tests definitely make sense. I wasn't sure which performance benchmarks would make sense here. Would it be fine to run the existing ones and see what impact the changes have?
I would do more end-to-end benchmarks, and focus on latency-sensitive ones, since that is where this type of change matters. For example, LM inference with a smallish LM (like a 4-bit model, 1-3B in size) would be a good place to start (you can use mlx-lm for that).
I'm still working on some good, simple tests; I ran into a few more errors with the previously proposed changes. But I wanted to ask what you thought of this approach. I appreciate any feedback.
I've spot-checked the default model with mlx_lm.generate (mlx-community/Llama-3.2-3B-Instruct-4bit) and didn't see any noticeable differences, but I'll do a more robust benchmark with a wider range of model sizes, as you suggested.
I like this new approach as it's much simpler. Though I do wonder about the possibility of deadlock. Say we have two streams:
- Stream A is waiting on the output of Stream B
- Stream A is holding the Metal lock in synchronize
- Stream B gets stuck waiting to get the lock so it can run the eval
Something like that seems plausible in a multi-threaded setup. I'm not sure it's necessarily a dealbreaker, because sharing graphs across threads is not a good idea for other reasons. But it would be good to set up a few C++ tests to really exercise the multi-threaded cases we expect this to work for.
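The Stream A / Stream B scenario can be modeled with plain Python threads. This is a hypothetical sketch, not MLX code. One global `threading.Lock` stands in for the Metal lock, a `threading.Event` stands in for Stream B's output, and timeouts let the simulation detect the deadlock instead of hanging. `simulate_stream_deadlock` and its parameters are illustrative names.

```python
import threading


def simulate_stream_deadlock(hold_lock_while_waiting=True, timeout=0.5):
    """Return True if the Stream A / Stream B scenario deadlocks.

    Hypothetical model, not MLX code: metal_lock stands in for the single
    Metal lock and b_output_ready for Stream B's result. Timeouts turn a
    real hang into a detectable failure.
    """
    metal_lock = threading.Lock()
    b_output_ready = threading.Event()
    a_has_locked = threading.Event()
    result = {}

    def stream_a():
        metal_lock.acquire()              # A takes the lock in "synchronize"
        a_has_locked.set()
        if not hold_lock_while_waiting:
            metal_lock.release()          # variant: drop the lock before blocking
        # A waits for B's output (possibly while still holding the lock)
        result["a_got_output"] = b_output_ready.wait(timeout)
        if hold_lock_while_waiting:
            metal_lock.release()

    def stream_b():
        a_has_locked.wait()               # make sure A grabbed the lock first
        # B needs the lock to run its eval; give up well before A's wait expires
        got_lock = metal_lock.acquire(timeout=timeout / 2)
        result["b_got_lock"] = got_lock
        if got_lock:
            b_output_ready.set()          # B "produces" its output
            metal_lock.release()

    threads = [threading.Thread(target=stream_a), threading.Thread(target=stream_b)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Deadlock: neither stream made progress before its timeout
    return not result["a_got_output"] and not result["b_got_lock"]


print(simulate_stream_deadlock(hold_lock_while_waiting=True))
print(simulate_stream_deadlock(hold_lock_while_waiting=False))
```

With the lock held while waiting, the simulation reports a deadlock. Releasing the lock before blocking on another stream's output avoids it. This is one possible mitigation under this toy model, not necessarily what the PR should do.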
I've added a few tests, how do they look to you?
The changes caused one buffer-related test (tests/array_tests.cpp, "test array shared buffer") to fail very occasionally. I think this is related to how the deleter is handled when the doctest test case ends. I added a synchronize call to that test, if that makes sense there.
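Failures that only show up "very occasionally" are easier to reproduce if a test body is run many times across threads that start together. The real tests are C++ (doctest). The following is only a Python sketch of that stress pattern. `stress` and `flaky_example` are hypothetical names, and in a real test `test_fn` would run an eval on its own stream.

```python
import threading


def stress(test_fn, num_threads=4, iterations=100):
    """Run test_fn(thread_id, iteration) concurrently and collect failures.

    Hypothetical helper: a Barrier releases all threads at once to maximize
    overlap, and many iterations help surface intermittent failures.
    """
    barrier = threading.Barrier(num_threads)
    failures = []
    failures_lock = threading.Lock()

    def worker(thread_id):
        barrier.wait()  # start all threads at (roughly) the same moment
        for iteration in range(iterations):
            try:
                test_fn(thread_id, iteration)
            except Exception as exc:  # record the failure and keep going
                with failures_lock:
                    failures.append((thread_id, iteration, repr(exc)))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return failures


def flaky_example(thread_id, iteration):
    # Placeholder test body: fails on exactly one (thread, iteration) pair
    if thread_id == 0 and iteration == 3:
        raise AssertionError("simulated intermittent failure")


failures = stress(flaky_example, num_threads=4, iterations=50)
print(f"{len(failures)} failure(s): {failures}")
```

Collecting failures instead of stopping at the first one shows how often a problem occurs, which helps judge whether a fix (such as an added synchronize) actually removes it or just makes it rarer.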
@awni I ran a few benchmarks; apologies for the delay! Results are below, for prompt TPS and generation TPS (tokens per second).
Prompt TPS
| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-1B-Instruct-4bit | 1942.43 (±10.34) | 1954.47 (±8.67) |
| mlx-community/Llama-3.2-3B-Instruct-4bit | 696.70 (±26.15) | 758.26 (±19.69) |
| mlx-community/Qwen3-0.6B-4bit | 2592.38 (±21.81) | 2605.61 (±24.54) |
| mlx-community/Qwen3-0.6B-6bit | 2491.78 (±12.08) | 2482.57 (±18.17) |
| mlx-community/Qwen3-0.6B-8bit | 2526.31 (±20.74) | 2541.08 (±10.92) |
| mlx-community/Qwen3-1.7B-3bit | 1121.43 (±3.49) | 1098.90 (±35.14) |
| mlx-community/Qwen3-1.7B-4bit | 1124.30 (±13.53) | 1134.61 (±4.48) |
Generation TPS
| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-1B-Instruct-4bit | 273.97 (±0.32) | 274.73 (±0.44) |
| mlx-community/Llama-3.2-3B-Instruct-4bit | 111.19 (±0.68) | 111.66 (±0.25) |
| mlx-community/Qwen3-0.6B-4bit | 290.42 (±4.89) | 283.89 (±10.23) |
| mlx-community/Qwen3-0.6B-6bit | 287.13 (±6.04) | 264.96 (±14.38) |
| mlx-community/Qwen3-0.6B-8bit | 261.41 (±1.57) | 258.06 (±0.56) |
| mlx-community/Qwen3-1.7B-3bit | 210.59 (±0.25) | 205.58 (±2.31) |
| mlx-community/Qwen3-1.7B-4bit | 185.28 (±0.44) | 184.66 (±0.38) |
The benchmark was pretty simple: the prompt was very short (I could make it longer). I set max tokens to 1000, which the Qwen models sometimes reached in my benchmark. Here's the code for reference. More trials could be run, and with a longer prompt, but hopefully this gives a decent idea of the performance difference.
```python
from mlx_lm import load, stream_generate
import pandas as pd

max_tokens = 1_000
verbose = False
warmup_count = 3  # untimed runs per model before measuring
num_trials = 10   # timed runs per model

checkpoints = [
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "mlx-community/Qwen3-0.6B-4bit",
    "mlx-community/Qwen3-0.6B-6bit",
    "mlx-community/Qwen3-0.6B-8bit",
    "mlx-community/Qwen3-1.7B-3bit",
    "mlx-community/Qwen3-1.7B-4bit",
]

rows = []
for checkpoint in checkpoints:
    model, tokenizer = load(path_or_hf_repo=checkpoint)

    prompt = (
        "Hello! I'm teaching a science class on our solar system and wanted to ask for your help! "
        "Could you tell what the planets in our solar system are called, and a little about each one?"
    )
    conversation = [{"role": "user", "content": prompt}]
    # Format the prompt with the model's chat template (returns token ids)
    prompt = tokenizer.apply_chat_template(
        conversation=conversation, add_generation_prompt=True
    )

    # Warm-up runs: results are discarded
    for _ in range(warmup_count):
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)

    # Timed trials
    for i in range(num_trials):
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
        # The last response yielded carries the final throughput and memory stats
        rows.append(
            {
                "model": checkpoint,
                "trial": i,
                "prompt_tokens": response.prompt_tokens,
                "prompt_tps": response.prompt_tps,
                "generation_tokens": response.generation_tokens,
                "generation_tps": response.generation_tps,
                "peak_memory": response.peak_memory,
            }
        )

df_results = pd.DataFrame(rows)
print(df_results.head())
df_results.to_csv("trial_runs.csv", index=False)
```
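The thread doesn't say how the "mean (±x)" cells were computed from trial_runs.csv. Assuming ± is the sample standard deviation over the timed trials, a stdlib-only sketch for producing those cells per model could look like this. `summarize` and `load_rows` are hypothetical helper names.

```python
import csv
import statistics
from collections import defaultdict


def summarize(rows, metric):
    """Group trial rows by model and format each as 'mean (±std)'.

    Assumes rows shaped like the benchmark's per-trial dicts, and that ± in
    the tables means the sample standard deviation (needs >= 2 trials).
    """
    by_model = defaultdict(list)
    for row in rows:
        by_model[row["model"]].append(float(row[metric]))
    return {
        model: f"{statistics.mean(vals):.2f} (±{statistics.stdev(vals):.2f})"
        for model, vals in by_model.items()
    }


def load_rows(path):
    """Read the per-trial CSV written by the benchmark script."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))
```

Usage would be along the lines of `summarize(load_rows("trial_runs.csv"), "prompt_tps")`. With pandas already loaded, `df_results.groupby("model")["prompt_tps"].agg(["mean", "std"])` gives the same numbers, since pandas' `std` also defaults to the sample standard deviation.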
As usual, any feedback is greatly appreciated!
The Llama-3.2-3B-Instruct-4bit prompt-processing results are a bit concerning: it's the only result where the throughput reduction falls outside the variability in your testing, and it also happens to be the largest model. It'd be good to benchmark some larger models to see if a trend develops.
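One way to make "outside of the variability" concrete is to check whether the mean ± sd ranges overlap and to compute Welch's t-statistic for the difference in means. This sketch assumes each ± is a sample standard deviation over n = 10 trials per side, matching `num_trials` in the benchmark script. The thread doesn't confirm that, and `compare` is a hypothetical helper.

```python
import math


def compare(mean_a, sd_a, mean_b, sd_b, n=10):
    """Compare two 'mean (±sd)' cells; returns (ranges_overlap, welch_t).

    Assumes ± is a sample standard deviation over n trials on each side.
    """
    # Rough heuristic: do the mean ± sd ranges overlap at all?
    overlap = (mean_a - sd_a) <= (mean_b + sd_b) and (mean_b - sd_b) <= (mean_a + sd_a)
    # Welch's t-statistic: difference in means divided by its standard error
    welch_t = (mean_a - mean_b) / math.sqrt(sd_a**2 / n + sd_b**2 / n)
    return overlap, welch_t


# Prompt TPS from the first table: this PR vs. current release
print(compare(696.70, 26.15, 758.26, 19.69))   # Llama-3.2-3B-Instruct-4bit
print(compare(1121.43, 3.49, 1098.90, 35.14))  # Qwen3-1.7B-3bit
```

Under these assumptions, the Llama 3B prompt-TPS ranges don't overlap and |t| is roughly 6, while the Qwen3-1.7B-3bit ranges do overlap. Range overlap is only a heuristic; more trials would tighten both checks.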
It does seem to be slightly slower (with some high variance in prompt TPS). I'm not sure where to go from here: whether this makes the PR a no-go, whether I should try relaxing some of the mutex locks, or whether it's something with my benchmark. Here are results with some larger models:
Prompt TPS
| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-3B-Instruct-4bit | 733.51 (±41.78) | 768.42 (±1.88) |
| mlx-community/gemma-3-4b-it-4bit | 478.45 (±3.77) | 484.88 (±7.67) |
| mlx-community/gemma-3-12b-it-4bit | 170.33 (±0.59) | 170.33 (±4.11) |
| mlx-community/gemma-3-12b-it-8bit | 125.07 (±36.62) | 154.84 (±0.94) |
Generation TPS
| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-3B-Instruct-4bit | 111.53 (±0.84) | 111.91 (±0.14) |
| mlx-community/gemma-3-4b-it-4bit | 85.00 (±0.06) | 86.37 (±0.20) |
| mlx-community/gemma-3-12b-it-4bit | 31.84 (±0.03) | 32.33 (±0.05) |
| mlx-community/gemma-3-12b-it-8bit | 17.43 (±0.69) | 18.25 (±0.01) |