Metal thread safety

Open acsweet opened this issue 8 months ago • 9 comments

Proposed changes

These changes are an attempt to improve thread safety for the Metal backend. This is related to #2067. Please let me know what you think.

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [ ] I have added tests that prove my fix is effective or that my feature works
  • [ ] I have updated the necessary documentation (if needed)

acsweet avatar Apr 22 '25 00:04 acsweet

This looks interesting. But in general I'm not convinced we need to go this route of fine-grained locking. It might work just as well, maybe even better, and be a lot cleaner and faster to have a single lock where we do the task submission in the main eval loop.

A couple higher level comments:

  • It would be good to add tests which show the problem and that the solution works
  • We'll need to do some performance benchmarking to be sure there are no regressions from any solution we come up with
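
To make the coarse-grained alternative concrete, here is a minimal sketch in plain Python rather than MLX's C++ backend. The names `submit_lock`, `submitted`, and `eval_graph` are hypothetical stand-ins, not MLX APIs. It only illustrates the idea of serializing the submission step with one lock while the rest of each thread's work runs unlocked.

```python
import threading

# Hypothetical stand-ins for the backend's submission path; not MLX's real API.
submit_lock = threading.Lock()
submitted = []


def eval_graph(thread_id, num_tasks):
    for task in range(num_tasks):
        # Per-thread work (graph traversal, scheduling) would happen here,
        # outside the lock.
        item = (thread_id, task)
        # Only the submission itself is serialized by the single lock.
        with submit_lock:
            submitted.append(item)


threads = [threading.Thread(target=eval_graph, args=(t, 100)) for t in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(submitted))
```

With one lock there is a single critical section to reason about, which is the simplicity argument made above; whether it is also faster is something only benchmarks can show.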

awni avatar Apr 22 '25 13:04 awni

Thank you! Let me see what I can do along those lines.

The tests definitely make sense, but I wasn't sure which performance benchmarks would be appropriate here. Should I run the existing ones and see what impact the changes have?

acsweet avatar Apr 22 '25 16:04 acsweet

Should I run the existing ones and see what impact the changes have?

I would do more end-to-end benchmarks, and focus on the more latency-sensitive ones, since this type of change matters there. So for example LM inference with a smallish LM (like 4-bit, 1-3B in size) would be a good place to start (you can use mlx-lm for that).

awni avatar Apr 22 '25 17:04 awni

I'm still working on some good simple tests, and I ran into a few more errors with the previously proposed changes. But I wanted to ask what you think of this approach. I appreciate any feedback.

I've spot-checked the default model with mlx_lm.generate (mlx-community/Llama-3.2-3B-Instruct-4bit) and didn't see any noticeable differences, but I'll do a more robust benchmark with a wider range of model sizes, as you suggested.

acsweet avatar Apr 30 '25 23:04 acsweet

I like this new approach as it's much simpler. Though I do wonder about the possibility of deadlock. Say we have two streams:

  • Stream A is waiting on the output of Stream B
  • Stream A is holding the metal lock in synchronize
  • Stream B gets stuck waiting to get the lock so it can run the eval

Something like that seems plausible in a multi-threaded setup. I'm not sure it's necessarily a dealbreaker, because sharing graphs across threads is not a good idea for other reasons. But it would be good to set up a few C++ tests to really exercise the multi-threaded cases we expect this to work for.
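
The scenario can be reproduced in miniature with plain Python threads. This is a sketch under stated assumptions, not MLX code: `metal_lock`, `b_output_ready`, `stream_a`, and `stream_b` are hypothetical names, and timeouts are used so the example reports the stall instead of hanging forever.

```python
import threading

metal_lock = threading.Lock()        # hypothetical single backend lock
b_output_ready = threading.Event()   # set once Stream B has produced its output
a_holds_lock = threading.Event()     # lets B start only after A owns the lock
results = {}


def stream_a():
    # Stream A takes the lock (as in "synchronize") and then waits on B's output.
    with metal_lock:
        a_holds_lock.set()
        results["a_got_output"] = b_output_ready.wait(timeout=1.0)


def stream_b():
    a_holds_lock.wait()
    # Stream B needs the same lock before it can run the eval A is waiting for.
    acquired = metal_lock.acquire(timeout=0.1)
    results["b_got_lock"] = acquired
    if acquired:
        b_output_ready.set()
        metal_lock.release()


threads = [threading.Thread(target=stream_a), threading.Thread(target=stream_b)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)
```

Both values come back `False`: B cannot get the lock while A holds it, and A never sees B's output. Without the timeouts, both threads would block indefinitely, which is the deadlock described above.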

awni avatar May 01 '25 19:05 awni

I've added a few tests. How do they look to you?

The changes caused one test around buffers to fail very occasionally (tests/array_tests.cpp "test array shared buffer"). I think this is related to how the deleter is handled with doctest when the test ends. I added a synchronize call to that test, if that makes sense there.

acsweet avatar May 07 '25 19:05 acsweet

I ran a few benchmarks, apologies for the delay! @awni Results are below, prompt tps and generated tps.

Prompt TPS

| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-1B-Instruct-4bit | 1942.43 (±10.34) | 1954.47 (±8.67) |
| mlx-community/Llama-3.2-3B-Instruct-4bit | 696.70 (±26.15) | 758.26 (±19.69) |
| mlx-community/Qwen3-0.6B-4bit | 2592.38 (±21.81) | 2605.61 (±24.54) |
| mlx-community/Qwen3-0.6B-6bit | 2491.78 (±12.08) | 2482.57 (±18.17) |
| mlx-community/Qwen3-0.6B-8bit | 2526.31 (±20.74) | 2541.08 (±10.92) |
| mlx-community/Qwen3-1.7B-3bit | 1121.43 (±3.49) | 1098.90 (±35.14) |
| mlx-community/Qwen3-1.7B-4bit | 1124.30 (±13.53) | 1134.61 (±4.48) |

Generation TPS

| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-1B-Instruct-4bit | 273.97 (±0.32) | 274.73 (±0.44) |
| mlx-community/Llama-3.2-3B-Instruct-4bit | 111.19 (±0.68) | 111.66 (±0.25) |
| mlx-community/Qwen3-0.6B-4bit | 290.42 (±4.89) | 283.89 (±10.23) |
| mlx-community/Qwen3-0.6B-6bit | 287.13 (±6.04) | 264.96 (±14.38) |
| mlx-community/Qwen3-0.6B-8bit | 261.41 (±1.57) | 258.06 (±0.56) |
| mlx-community/Qwen3-1.7B-3bit | 210.59 (±0.25) | 205.58 (±2.31) |
| mlx-community/Qwen3-1.7B-4bit | 185.28 (±0.44) | 184.66 (±0.38) |

The benchmark was pretty simple, and the prompt was very short (I could make it longer). I set the max tokens to 1000, which the Qwen models sometimes reached in my benchmark. Here's the code for reference. More trials could be run, and with a longer prompt, but hopefully this gives a decent idea of the time difference.

from mlx_lm import load, stream_generate
import pandas as pd

max_tokens = 1_000
verbose = False
warmup_count = 3
num_trials = 10
df_results = pd.DataFrame()

checkpoints = [
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/Llama-3.2-3B-Instruct-4bit",
    "mlx-community/Qwen3-0.6B-4bit",
    "mlx-community/Qwen3-0.6B-6bit",
    "mlx-community/Qwen3-0.6B-8bit",
    "mlx-community/Qwen3-1.7B-3bit",
    "mlx-community/Qwen3-1.7B-4bit",
]

for checkpoint in checkpoints:
    model, tokenizer = load(path_or_hf_repo=checkpoint)

    prompt = "Hello! I'm teaching a science class on our solar system and wanted to ask for your help! " \
        "Could you tell what the planets in our solar system are called, and a little about each one?"
    conversation = [{"role": "user", "content": prompt}]

    prompt = tokenizer.apply_chat_template(
        conversation=conversation, add_generation_prompt=True
    )

    for _ in range(warmup_count):
        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

    for i in range(num_trials):

        text = ""
        for response in stream_generate(model, tokenizer, prompt, max_tokens=max_tokens):
            if verbose:
                print(response.text, end="", flush=True)
            text += response.text

        response_dict = {
            'model': checkpoint,
            'trial': i,
            'prompt_tokens': response.prompt_tokens,
            'prompt_tps': response.prompt_tps,
            'generation_tokens': response.generation_tokens,
            'generation_tps': response.generation_tps,
            'peak_memory': response.peak_memory,
        }

        df_trial = pd.DataFrame(response_dict, index=[0])
        df_results = pd.concat([df_results, df_trial], ignore_index=True)


print(df_results.head())
df_results.to_csv('trial_runs.csv', index=False)
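
The script above only writes per-trial rows; the step that turns them into the "mean (±spread)" figures in the tables is not shown in this thread. A minimal sketch of one way to do it follows, assuming the spread is the sample standard deviation across trials (an assumption, not something stated here). The `summarize` helper and the sample rows are hypothetical; only the column names `model` and `generation_tps` come from the script.

```python
import statistics
from collections import defaultdict


def summarize(rows, metric):
    """Return {model: (mean, sample standard deviation)} for one metric."""
    by_model = defaultdict(list)
    for row in rows:
        by_model[row["model"]].append(float(row[metric]))
    # statistics.stdev needs at least two trials per model.
    return {
        model: (statistics.mean(values), statistics.stdev(values))
        for model, values in by_model.items()
    }


# Made-up rows, only to show the input shape (values are not benchmark results).
rows = [
    {"model": "model-a", "generation_tps": "100.0"},
    {"model": "model-a", "generation_tps": "102.0"},
    {"model": "model-a", "generation_tps": "104.0"},
]

summary = summarize(rows, "generation_tps")
for model, (mean, std) in summary.items():
    print(f"{model} {mean:.2f} (±{std:.2f})")
```

The same function should work on the real output by passing it the rows from `csv.DictReader` opened on `trial_runs.csv`, once for `prompt_tps` and once for `generation_tps`.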

As usual, any feedback is greatly appreciated!

acsweet avatar May 13 '25 08:05 acsweet

The Llama-3.2-3B-Instruct-4bit prompt processing results are a bit concerning since it's the only result where throughput reduction is outside of the variability in your testing, and it also happens to be the largest model. It seems like it'd be good to benchmark some larger models to see if a trend develops.
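
One rough way to check the "outside of the variability" observation is to ask whether the two mean ± spread ranges overlap for each model. The sketch below uses the prompt TPS figures reported earlier in this thread; the overlap criterion and the helper name `ranges_overlap` are illustrative choices, not a formal significance test.

```python
# (model, PR mean, PR spread, release mean, release spread), prompt TPS
prompt_tps = [
    ("mlx-community/Llama-3.2-1B-Instruct-4bit", 1942.43, 10.34, 1954.47, 8.67),
    ("mlx-community/Llama-3.2-3B-Instruct-4bit", 696.70, 26.15, 758.26, 19.69),
    ("mlx-community/Qwen3-0.6B-4bit", 2592.38, 21.81, 2605.61, 24.54),
    ("mlx-community/Qwen3-0.6B-6bit", 2491.78, 12.08, 2482.57, 18.17),
    ("mlx-community/Qwen3-0.6B-8bit", 2526.31, 20.74, 2541.08, 10.92),
    ("mlx-community/Qwen3-1.7B-3bit", 1121.43, 3.49, 1098.90, 35.14),
    ("mlx-community/Qwen3-1.7B-4bit", 1124.30, 13.53, 1134.61, 4.48),
]


def ranges_overlap(mean_a, spread_a, mean_b, spread_b):
    # The ranges [mean - spread, mean + spread] overlap when the gap between
    # the means is no larger than the two spreads combined.
    return abs(mean_a - mean_b) <= spread_a + spread_b


flagged = [
    model
    for model, pr, pr_spread, rel, rel_spread in prompt_tps
    if not ranges_overlap(pr, pr_spread, rel, rel_spread)
]
print(flagged)
```

Only the Llama-3.2-3B row is flagged: the means differ by 61.56 tokens per second, while the combined spread is 45.84.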

altaic avatar May 13 '25 23:05 altaic

It does seem to be slightly slower (with some high variance on the prompt TPS). I'm not sure where to go from here. If this makes the PR a no-go, I can try relaxing some of the mutex locks where possible, or check whether the problem is with my benchmark.

Prompt TPS

| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-3B-Instruct-4bit | 733.51 (±41.78) | 768.42 (±1.88) |
| mlx-community/gemma-3-4b-it-4bit | 478.45 (±3.77) | 484.88 (±7.67) |
| mlx-community/gemma-3-12b-it-4bit | 170.33 (±0.59) | 170.33 (±4.11) |
| mlx-community/gemma-3-12b-it-8bit | 125.07 (±36.62) | 154.84 (±0.94) |

Generation TPS

| Model | This PR | Current MLX Release |
|---|---|---|
| mlx-community/Llama-3.2-3B-Instruct-4bit | 111.53 (±0.84) | 111.91 (±0.14) |
| mlx-community/gemma-3-4b-it-4bit | 85.00 (±0.06) | 86.37 (±0.20) |
| mlx-community/gemma-3-12b-it-4bit | 31.84 (±0.03) | 32.33 (±0.05) |
| mlx-community/gemma-3-12b-it-8bit | 17.43 (±0.69) | 18.25 (±0.01) |

acsweet avatar May 14 '25 07:05 acsweet