JIT compile option for binary minimization
- Adds a build flag `MLX_METAL_JIT` to reduce the Metal library size by using runtime compilation.
- Big refactor of unary, binary, ternary, copy, scatter, and gather to allow JIT compilation.
- Current MTL library size: 15M `mlx.metallib`
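For context, enabling the option would look something like the following. This is a sketch of a typical CMake flag invocation; the exact target names and any additional options are assumptions, not taken from this PR.

```shell
# Hypothetical: configure with runtime (JIT) Metal compilation enabled,
# so the large precompiled mlx.metallib is not built.
cmake -B build -DMLX_METAL_JIT=ON
cmake --build build
```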
Benchmarks:
No degradation in token generation:
```shell
python -m mlx_lm.generate --model mlx-community/NeuralBeagle14-7B-4bit-mlx --prompt "Write a story about Albert Einstein" --temp 0.0 --max-tokens 256
```
Pre:
Prompt: 219.423 tokens-per-sec
Generation: 107.316 tokens-per-sec
Post:
Prompt: 219.580 tokens-per-sec
Generation: 107.562 tokens-per-sec
Transformer training:
Pre: Iter 30: Train loss 7.943, It/sec 5.911, Peak memory 5.534 (GB)
Post: Iter 30: Train loss 7.923, It/sec 5.912, Peak memory 5.534 (GB)
LeNet training:
Pre: Test accuracy 0.982, Time 2.792 (s)
Post: Test accuracy 0.983, Time 2.798 (s)
MNIST:
Pre: Test accuracy 0.937, Time 0.639 (s)
Post: Test accuracy 0.929, Time 0.638 (s)
@jagrit06 @angeloskath I think this is ready for review.
The main things to look at are:
- Updated way the compiled includes are made: `metal/CMakeLists.txt`, `metal/jit/includes.h`, and `metal/make_compiled_preamble.sh`
- The way primitives get or build kernels: `metal/kernels.h` and the corresponding implementations in `metal/jit_kernels.cpp` and `metal/nojit_kernels.cpp`
- An example of how that works (mostly reorganizing code): e.g. `metal/kernels/unary.h`, `metal/kernels/unary.metal`, `metal/unary.cpp`, and the format template in `metal/jit/unary.h`. They all follow the same pattern.