Coalesced memory reads for slightly faster LLM inference
Background
I enjoyed reading the posts related to this tweet: https://twitter.com/awnihannun/status/1776275621926375498. I'd like to contribute to the MLX side of the challenge with Llama.cpp, even if it's just a small improvement. :)
Proposed changes
This PR modifies the way scales and biases are read from memory, making the accesses more coalesced.
Although these values are sparsely laid out from the point of view of a single SIMD thread, they can be read from adjacent addresses when all the threads of a SIMD group are considered together.
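To make the idea concrete, here is a toy Python sketch of the access pattern across one SIMD group; the SIMD width, the number of values per thread, and the address layout below are illustrative assumptions, not the actual layout used in quantized.metal.

```python
# Toy illustration of "strided" vs. "coalesced" reads across one SIMD group.
# The constants are assumptions for the sake of the example only.

SIMD_SIZE = 32         # threads in one SIMD group (typical on Apple GPUs)
VALUES_PER_THREAD = 4  # scales/biases each thread ultimately needs

def strided_addresses(step: int) -> list[int]:
    # Each thread owns a contiguous chunk, so at a given step the 32 threads
    # touch addresses that are VALUES_PER_THREAD apart -> poorly coalesced.
    return [t * VALUES_PER_THREAD + step for t in range(SIMD_SIZE)]

def coalesced_addresses(step: int) -> list[int]:
    # Neighboring threads read neighboring addresses at each step; the whole
    # SIMD group touches one contiguous block of memory -> coalesced.
    return [step * SIMD_SIZE + t for t in range(SIMD_SIZE)]

if __name__ == "__main__":
    print("strided, step 0:  ", strided_addresses(0)[:8], "...")
    print("coalesced, step 0:", coalesced_addresses(0)[:8], "...")
```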
Performance Benchmarks
Performance improved around 0.2% with this change.
- Machine: M3 Max (128GB)
- LLM: mistral-7b (4-bit quantized)
- Client: https://github.com/ml-explore/mlx-examples
- Command: `python -m mlx_lm.generate --model mlx_model --prompt "write a quicksort algorithm in c++" --max-tokens 256`
- Before: main (tag: v0.9.1)
- After: this PR
Token Processing Speed (tokens per second, higher is better, 15 runs):
| Statistic | Before | After | Speedup |
|---|---|---|---|
| Median | 68.554 | 68.680 | 0.18% |
| Mean | 68.557 | 68.734 | 0.26% |
| Min | 68.210 | 68.592 | 0.56% |
| Max | 68.800 | 69.087 | 0.42% |
Checklist
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code and installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
Additional Notes
- All tests passed.
- Applying `clang-format` to `quantized.metal` resulted in significant formatting changes. To avoid overwhelming diffs, I opted to bypass the linting step for this file.
I spent quite some time playing with it, and it is really, really hard to figure out if there is any performance benefit. I changed the unrolled read to a `constexpr` for loop that would get rid of the static assertion, and ran 100 generations for each of: the for loop in two variants (in one of them the pointer advancement is done later), this PR, and v0.9.1.
I got the following graph.
I believe that it doesn't really matter and that the variation has to do with throttling or whatever other operations are happening on the machine at the time. There is absolutely no reason the for loop should be faster than the unrolled loop, and even if it is, it is by a tiny amount.
@youknow04 let me know what you think.
Thank you for your detailed feedback and insights. Let me address each of your points:
- **Tiny performance gain issue:** Matrix-vector multiplication is a fundamental operation and the core computational load in Transformers. Optimizing this primitive operation can indeed have a persistent effect on all of its variants.
- **Faster for loop issue:** First, I should mention that this is only my second time programming in Metal, so I don't have expertise in the Metal framework itself. However, I have done this kind of low-latency optimization many times, and unrolling is not always faster than a for loop, especially when the unrolled code is too long. The code itself has to live in memory, so a larger code size implies more memory overhead, such as cache misses. That said, I didn't expect the for loop to be faster than the unrolled version here. I appreciate your intensive checking. (See the sketch after this list for what the two forms look like.)
- **Experiment issue:** I don't think it's a throttling issue, because I manually checked each run. I tested performance with GPU monitoring, ran the LLM only when GPU utilization was idle, left enough rest between runs to prevent throttling, and the cooling fan was never triggered on my M3 Max (128GB) during any of the tests.

I think we now need the LLN (Law of Large Numbers) to test this tiny gain in the LLM. MLX is still in its early stages, and you must have a lot of work to do. I want to support the MLX team and really don't want to take your time for this hacky gain. I will write code for a more intensive test and share the results again.
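For readers unfamiliar with the trade-off mentioned above, here is a conceptual Python sketch of the two forms; it only shows what "unrolled" versus "loop" means, and is not the actual Metal kernel code.

```python
# Conceptual sketch only: the same 4-element accumulation written unrolled
# and as a loop. On a GPU the unrolled form avoids loop overhead but makes
# the compiled kernel larger, which can start to cost more (e.g. instruction
# cache pressure) once the unrolled body gets long.

def dot4_unrolled(w, x):
    return w[0] * x[0] + w[1] * x[1] + w[2] * x[2] + w[3] * x[3]

def dot4_loop(w, x):
    acc = 0.0
    for i in range(4):
        acc += w[i] * x[i]
    return acc

assert dot4_unrolled([1, 2, 3, 4], [5, 6, 7, 8]) == dot4_loop([1, 2, 3, 4], [5, 6, 7, 8])
```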
I tried to reproduce the results using the following steps:
- reboot my Mac
- run `sudo sysctl iogpu.wired_lwm_mb=100000`
- run the following Python script
  - the script randomizes the run order to prevent possible biases
Python code used to reproduce:
```python
import json
import os
import random
import statistics
import subprocess
import time
from dataclasses import dataclass

MLX_LM_PATH = os.getcwd()
MLX_PATH = os.path.expanduser("~/workspace/mlx")
NUM_BATCH = 32
NUM_ITER = 999
RESULT_FILE = "result.jsonl"


@dataclass
class Target:
    branch: str
    commit_hash: str


targets = [
    Target("main", "bddf23f175726a57f0e443cd45518c0757daa166"),
    Target("coalesced-mem", "ab58e3718a3e31687e3ef5e8914034855990454f"),
    Target("coalesced-mem-unroll-t", "28057ab228de7d1b1042622559b4a1ef7ba14a12"),
    # Target("coalesced-mem-for-1", "893a6327b7ec1d2b391c6e8e27822fecbf5d0e04"),
    # Target("coalesced-mem-for-2", "af9a0269a35ce68409e4f4f5e09d6f82da55c835"),
]


def run_command(command: list[str]):
    return subprocess.check_output(command).decode("utf-8")


def setup_mlx(target: Target):
    # Build and install the MLX branch under test, then verify the installed version.
    os.chdir(MLX_PATH)
    run_command(["git", "switch", target.branch])
    print(f"installing mlx[{target.branch}]")
    run_command(["pip", "install", "."])
    pip_list_result = subprocess.check_output(["pip", "list"]).decode("utf-8")
    if f"+{target.commit_hash[:8]}" not in pip_list_result:
        raise RuntimeError(f"wrong MLX version installed. {target.branch}")
    os.chdir(MLX_LM_PATH)


def show_result():
    # Aggregate tokens-per-second samples per branch from the result file.
    parsed: dict[str, list[float]] = {t.branch: [] for t in targets}
    with open(RESULT_FILE, "r") as f:
        for line in f:
            result = json.loads(line)
            for r in result:
                parsed[r["branch"]].append(float(r["tps"]))
    for branch, tps_values in parsed.items():
        mean_tps = statistics.mean(tps_values)
        median_tps = statistics.median(tps_values)
        print(f"{branch}: Mean = {mean_tps}, Median = {median_tps}, Min = {min(tps_values)}, Max = {max(tps_values)}")


if __name__ == "__main__":
    os.chdir(MLX_LM_PATH)
    for i in range(NUM_ITER):
        # Randomize the order of the targets in every iteration.
        r_samples = random.sample(targets, len(targets))
        print(f"{i}th iteration with {r_samples}")
        sample_result: list[dict[str, str | float]] = []
        for target in r_samples:
            setup_mlx(target)
            for b in range(NUM_BATCH):
                time.sleep(2)  # to prevent throttling
                llm_result = run_command(["python", "-m", "mlx_lm.generate", "--model", "mlx_mistral7b-8q", "--prompt", "'write a quicksort algorithm in c++'", "--max-tokens", "128"])
                tps = float(llm_result.split("Generation: ")[1].split(" tokens-per-sec")[0])
                sample_result.append({
                    "branch": target.branch,
                    "iter": i,
                    "batch": b,
                    "tps": tps,
                })
        with open(RESULT_FILE, "a") as f:
            f.write(json.dumps(sample_result) + "\n")
        # Print cumulative statistics after every iteration.
        show_result()
```
and got the following results:
- main: v0.9.1
- coalesced-mem: this PR
- coalesced-mem-unroll-t: this PR with https://github.com/ml-explore/mlx/pull/803
- coalesced-mem-for-1: this PR using one for loop
- coalesced-mem-for-2: this PR using two for loops
| Configuration | Mean | Median | Min | Max |
|---|---|---|---|---|
| coalesced-mem-for-1 | 69.817 | 69.808 | 69.427 | 71.343 |
| coalesced-mem-for-2 | 69.832 | 69.813 | 69.296 | 70.835 |
| main | 69.848 | 69.840 | 69.339 | 71.243 |
| coalesced-mem | 69.875 | 69.861 | 69.421 | 70.648 |
| coalesced-mem-unroll-t | 69.966 | 69.941 | 69.500 | 71.985 |
| Configuration | Mean | Median | Min | Max |
|---|---|---|---|---|
| main | 42.391 | 42.383 | 42.218 | 42.806 |
| coalesced-mem | 42.416 | 42.404 | 42.225 | 42.864 |
| coalesced-mem-unroll-t | 42.476 | 42.468 | 42.295 | 42.744 |
Result Summary:
- After rebooting and applying the `iogpu.wired_lwm_mb` setting, all inference speeds increased, and the speed gap between the different configurations was reduced.
- The rank order remained consistent.
- The combination of this PR and https://github.com/ml-explore/mlx/pull/803 shows the best performance.
- We now have clean results, and throttling or other operational issues have minimal impact thanks to randomization and the LLN (see the sketch after this list).
- This PR delivers improved performance, and it is even faster when combined with https://github.com/ml-explore/mlx/pull/803.
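As a follow-up to the LLN point, here is a minimal sketch (not part of the benchmark script above) of how one could check whether the gap between two branches is larger than run-to-run noise, assuming the result.jsonl format produced by the script; the branch names and resample count are just example parameters.

```python
import json
import random
import statistics

RESULT_FILE = "result.jsonl"  # produced by the benchmark script above
BASELINE, CANDIDATE = "main", "coalesced-mem"
N_RESAMPLES = 10_000

# Collect tokens-per-second samples per branch from result.jsonl.
samples: dict[str, list[float]] = {BASELINE: [], CANDIDATE: []}
with open(RESULT_FILE) as f:
    for line in f:
        for r in json.loads(line):
            if r["branch"] in samples:
                samples[r["branch"]].append(float(r["tps"]))

observed = statistics.mean(samples[CANDIDATE]) - statistics.mean(samples[BASELINE])

# Bootstrap the difference in means to get a rough 95% confidence interval.
diffs = []
for _ in range(N_RESAMPLES):
    a = random.choices(samples[CANDIDATE], k=len(samples[CANDIDATE]))
    b = random.choices(samples[BASELINE], k=len(samples[BASELINE]))
    diffs.append(statistics.mean(a) - statistics.mean(b))
diffs.sort()
low, high = diffs[int(0.025 * N_RESAMPLES)], diffs[int(0.975 * N_RESAMPLES)]

print(f"mean tps difference: {observed:.3f} (95% CI {low:.3f} .. {high:.3f})")
print("gain is distinguishable from noise" if low > 0 else "CI includes zero")
```

If the interval excludes zero, the measured speedup is unlikely to be explained by noise alone.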
In my opinion, for truly low-latency code, complex implementations are inevitable. We also have room for further improvement if we choose this direction.
@angeloskath let me know what you think.
It seems you have no interest in this PR. I will close this.