
JIT compile option for binary minimization

Open: awni opened this issue 1 year ago • 2 comments

  • Adds a build flag, MLX_METAL_JIT, that reduces the size of the Metal library by compiling kernels at runtime rather than precompiling them into the library.
  • Substantially refactors the unary, binary, ternary, copy, scatter, and gather kernels so they can be JIT compiled.
  • Current Metal library size: 15M (mlx.metallib).

awni, May 08 '24 17:05

Benchmarks:

No degradation in token generation:

python -m mlx_lm.generate --model mlx-community/NeuralBeagle14-7B-4bit-mlx --prompt "Write a story about Albert Einstein" --temp 0.0 --max-tokens 256
Pre:
Prompt: 219.423 tokens-per-sec
Generation: 107.316 tokens-per-sec

Post:
Prompt: 219.580 tokens-per-sec
Generation: 107.562 tokens-per-sec

Transformer training:

Pre: Iter 30: Train loss 7.943, It/sec 5.911, Peak memory 5.534 (GB)
Post: Iter 30: Train loss 7.923, It/sec 5.912, Peak memory 5.534 (GB)

LeNet training:

Pre: Test accuracy 0.982, Time 2.792 (s)
Post: Test accuracy 0.983, Time 2.798 (s)

MNIST:

Pre: Test accuracy 0.937, Time 0.639 (s)
Post: Test accuracy 0.929, Time 0.638 (s)

awni, May 09 '24 23:05

@jagrit06 @angeloskath I think this is ready for review.

awni, May 16 '24 03:05

For review, the main things to look at are:

  • The updated way the compiled includes are generated: metal/CMakeLists.txt, metal/jit/includes.h, and metal/make_compiled_preamble.sh
  • How primitives get or build kernels: metal/kernels.h and the corresponding implementations in metal/jit_kernels.cpp and metal/nojit_kernels.cpp
  • An example of how this works (mostly reorganized code): metal/kernels/unary.h, metal/kernels/unary.metal, metal/unary.cpp, and the format template in metal/jit/unary.h. The other ops all follow the same pattern.

awni, May 22 '24 14:05