MLX in pyhpc-benchmarks repo
I am the maintainer of pyhpc-benchmarks, a repository that compares the performance of various JIT compilers and ML frameworks on scientific computing workloads (mostly finite-difference stencils from an ocean model).
I have recently added MLX as a possible backend, and would like to give you the opportunity to correct inefficiencies or potential mistakes in my implementation before I publish the results. The branch with MLX support is here; the benchmarks are in benchmarks/*/*_mlx.py.
Right now, the performance of MLX on these workloads hovers between decent and poor (slower than NumPy in some cases). My guess is that this is due to the lack of true in-place operations, since the benchmarks use a lot of slicing and selective updates.
That said, the ergonomics of implementing the benchmarks in MLX were excellent, especially considering how recently MLX was released: everything worked pretty much out of the box. Great job!
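To make the in-place point concrete, here is a minimal sketch (not code from the repo) of the kind of access pattern these benchmarks rely on, written in NumPy: a finite-difference stencil applied via slicing, with a selective in-place update of the interior points. A framework without true in-place assignment must materialize a new array for every such update.

```python
import numpy as np

def diffuse_inplace(u, kappa=0.1):
    """Illustrative stencil update (hypothetical, not from pyhpc-benchmarks).

    Second-order centered difference, written as an in-place update of
    the interior points. NumPy mutates `u` directly; a framework without
    true in-place assignment has to allocate and copy instead.
    """
    u[1:-1] += kappa * (u[2:] - 2.0 * u[1:-1] + u[:-2])
    return u

u = np.arange(8.0) ** 2
diffuse_inplace(u)
```

The boundary points are left untouched, which is typical of the selective updates in the benchmarks.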
Example output on my M1 MacBook follows.
Equation of state benchmark
(elementwise ops, no slicing)
$ python run.py benchmarks/equation_of_state/ --device cpu
benchmarks.equation_of_state
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numba 100,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 2.711
4,096 taichi 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.002 2.522
4,096 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 2.007
4,096 aesara 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 1.718
4,096 numpy 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 1.000
4,096 tensorflow 10,000 0.001 0.001 0.000 0.000 0.001 0.001 0.008 0.838
4,096 pytorch 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 0.747
4,096 mlx 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.004 0.458
16,384 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 18.125
16,384 taichi 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 6.375
16,384 numba 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.002 6.117
16,384 tensorflow 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.005 4.873
16,384 aesara 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 4.263
16,384 mlx 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.004 2.642
16,384 numpy 10,000 0.006 0.001 0.003 0.005 0.005 0.006 0.016 1.000
16,384 pytorch 1,000 0.006 0.002 0.003 0.005 0.006 0.007 0.014 0.956
65,536 jax 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.005 15.042
65,536 taichi 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.005 4.457
65,536 tensorflow 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.005 4.256
65,536 numba 1,000 0.004 0.000 0.003 0.003 0.004 0.004 0.005 4.094
65,536 aesara 1,000 0.005 0.000 0.005 0.005 0.005 0.005 0.006 2.975
65,536 mlx 1,000 0.005 0.000 0.005 0.005 0.005 0.005 0.007 2.850
65,536 numpy 1,000 0.015 0.002 0.011 0.013 0.014 0.015 0.022 1.000
65,536 pytorch 1,000 0.022 0.004 0.013 0.020 0.022 0.025 0.041 0.648
262,144 jax 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.004 20.878
262,144 tensorflow 1,000 0.008 0.000 0.008 0.008 0.008 0.008 0.010 5.672
262,144 taichi 1,000 0.012 0.000 0.011 0.011 0.011 0.012 0.014 4.065
262,144 numba 1,000 0.013 0.000 0.013 0.013 0.013 0.013 0.014 3.647
262,144 aesara 1,000 0.017 0.000 0.017 0.017 0.017 0.017 0.019 2.696
262,144 mlx 1,000 0.022 0.001 0.021 0.022 0.022 0.022 0.028 2.111
262,144 numpy 100 0.047 0.005 0.041 0.044 0.045 0.047 0.070 1.000
262,144 pytorch 100 0.057 0.010 0.039 0.051 0.055 0.062 0.084 0.818
1,048,576 jax 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.013 26.042
1,048,576 tensorflow 100 0.033 0.001 0.032 0.032 0.033 0.033 0.036 6.600
1,048,576 taichi 100 0.048 0.001 0.047 0.048 0.048 0.049 0.051 4.487
1,048,576 numba 100 0.053 0.001 0.052 0.052 0.053 0.054 0.057 4.074
1,048,576 aesara 100 0.071 0.001 0.070 0.071 0.071 0.071 0.074 3.044
1,048,576 mlx 100 0.102 0.004 0.099 0.100 0.100 0.101 0.127 2.131
1,048,576 pytorch 100 0.184 0.009 0.167 0.178 0.182 0.189 0.209 1.178
1,048,576 numpy 100 0.216 0.011 0.194 0.208 0.214 0.223 0.252 1.000
4,194,304 jax 100 0.035 0.002 0.032 0.034 0.035 0.036 0.040 37.595
4,194,304 tensorflow 100 0.131 0.004 0.126 0.128 0.130 0.133 0.141 10.101
4,194,304 taichi 100 0.191 0.003 0.187 0.190 0.191 0.192 0.197 6.914
4,194,304 numba 100 0.211 0.003 0.207 0.208 0.210 0.213 0.219 6.270
4,194,304 aesara 100 0.280 0.003 0.273 0.279 0.280 0.282 0.289 4.713
4,194,304 mlx 10 0.445 0.028 0.403 0.425 0.450 0.462 0.500 2.967
4,194,304 pytorch 10 0.815 0.066 0.753 0.768 0.775 0.850 0.958 1.620
4,194,304 numpy 10 1.321 0.350 0.952 1.012 1.170 1.694 1.874 1.000
(time in wall seconds, less is better)
Isoneutral diffusion benchmark
(4D slicing and index shifts)
$ python run.py benchmarks/isoneutral_mixing/ --device cpu
benchmarks.isoneutral_mixing
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.001 0.001 0.000 0.000 0.001 0.001 0.010 2.456
4,096 numba 10,000 0.001 0.001 0.000 0.000 0.000 0.001 0.011 2.381
4,096 taichi 10,000 0.001 0.001 0.001 0.001 0.001 0.001 0.012 2.240
4,096 aesara 10,000 0.001 0.001 0.001 0.001 0.001 0.001 0.015 1.222
4,096 numpy 10,000 0.002 0.001 0.002 0.002 0.002 0.002 0.014 1.000
4,096 pytorch 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.012 0.765
4,096 mlx 1,000 0.010 0.001 0.009 0.009 0.009 0.010 0.023 0.186
16,384 numba 1,000 0.003 0.001 0.002 0.002 0.002 0.003 0.006 3.558
16,384 taichi 1,000 0.003 0.000 0.003 0.003 0.003 0.003 0.006 3.305
16,384 jax 1,000 0.003 0.001 0.002 0.003 0.003 0.004 0.007 2.810
16,384 aesara 1,000 0.006 0.001 0.005 0.006 0.006 0.006 0.014 1.495
16,384 numpy 1,000 0.009 0.002 0.005 0.008 0.009 0.010 0.017 1.000
16,384 pytorch 1,000 0.011 0.002 0.005 0.009 0.011 0.012 0.019 0.848
16,384 mlx 100 0.036 0.001 0.034 0.035 0.035 0.036 0.048 0.254
65,536 numba 1,000 0.010 0.001 0.009 0.010 0.010 0.011 0.014 2.199
65,536 jax 1,000 0.011 0.001 0.009 0.010 0.011 0.011 0.015 2.097
65,536 taichi 100 0.011 0.001 0.011 0.011 0.011 0.011 0.015 2.034
65,536 numpy 1,000 0.023 0.003 0.018 0.021 0.022 0.023 0.038 1.000
65,536 aesara 100 0.024 0.001 0.023 0.024 0.024 0.025 0.028 0.930
65,536 pytorch 100 0.028 0.004 0.019 0.025 0.027 0.030 0.039 0.814
65,536 mlx 100 0.139 0.002 0.136 0.137 0.138 0.139 0.145 0.163
262,144 jax 100 0.023 0.002 0.021 0.022 0.023 0.024 0.028 3.212
262,144 numba 100 0.039 0.001 0.037 0.038 0.038 0.039 0.043 1.953
262,144 taichi 100 0.046 0.001 0.045 0.045 0.046 0.047 0.050 1.626
262,144 numpy 100 0.075 0.010 0.066 0.069 0.070 0.078 0.119 1.000
262,144 pytorch 100 0.078 0.018 0.063 0.068 0.072 0.079 0.148 0.967
262,144 aesara 100 0.092 0.004 0.086 0.089 0.091 0.094 0.110 0.821
262,144 mlx 10 0.515 0.005 0.505 0.512 0.514 0.520 0.523 0.146
1,048,576 jax 10 0.108 0.005 0.098 0.105 0.108 0.111 0.118 3.306
1,048,576 numba 10 0.163 0.007 0.153 0.155 0.165 0.169 0.172 2.198
1,048,576 taichi 10 0.186 0.005 0.179 0.182 0.187 0.190 0.193 1.921
1,048,576 numpy 10 0.357 0.013 0.344 0.349 0.356 0.361 0.389 1.000
1,048,576 aesara 10 0.430 0.032 0.402 0.414 0.421 0.428 0.521 0.831
1,048,576 pytorch 10 0.473 0.026 0.401 0.474 0.477 0.484 0.500 0.756
1,048,576 mlx 10 2.177 0.023 2.137 2.161 2.173 2.195 2.214 0.164
4,194,304 jax 10 0.569 0.054 0.474 0.538 0.565 0.617 0.652 3.016
4,194,304 numba 10 0.669 0.013 0.651 0.657 0.669 0.681 0.688 2.566
4,194,304 taichi 10 0.726 0.020 0.709 0.711 0.719 0.725 0.766 2.366
4,194,304 numpy 10 1.718 0.137 1.526 1.586 1.740 1.788 1.931 1.000
4,194,304 aesara 10 1.896 0.117 1.695 1.810 1.922 1.995 2.042 0.906
4,194,304 pytorch 10 1.937 0.105 1.826 1.842 1.935 1.981 2.186 0.887
4,194,304 mlx 10 8.707 0.077 8.633 8.664 8.681 8.718 8.911 0.197
TKE benchmark
(3D slicing and iterative tridiagonal solves)
$ python run.py benchmarks/turbulent_kinetic_energy/ --device cpu
benchmarks.turbulent_kinetic_energy
===================================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 4.930
4,096 numba 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 1.917
4,096 numpy 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.005 1.000
4,096 pytorch 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.007 0.697
4,096 mlx 1,000 0.007 0.000 0.006 0.007 0.007 0.007 0.011 0.121
16,384 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.005 5.496
16,384 numba 1,000 0.002 0.000 0.001 0.001 0.001 0.002 0.003 2.128
16,384 pytorch 1,000 0.003 0.001 0.002 0.003 0.003 0.004 0.008 1.035
16,384 numpy 1,000 0.003 0.001 0.002 0.003 0.003 0.004 0.021 1.000
16,384 mlx 1,000 0.022 0.001 0.020 0.021 0.022 0.022 0.033 0.160
65,536 jax 1,000 0.003 0.001 0.002 0.003 0.003 0.003 0.023 4.253
65,536 numba 1,000 0.006 0.001 0.005 0.005 0.006 0.006 0.009 2.011
65,536 pytorch 100 0.011 0.002 0.007 0.010 0.011 0.012 0.016 1.104
65,536 numpy 100 0.012 0.002 0.009 0.011 0.011 0.012 0.024 1.000
65,536 mlx 100 0.080 0.002 0.077 0.079 0.079 0.080 0.087 0.149
262,144 jax 100 0.009 0.001 0.007 0.008 0.009 0.009 0.014 4.411
262,144 numba 100 0.021 0.002 0.019 0.020 0.021 0.022 0.031 1.832
262,144 pytorch 100 0.034 0.009 0.026 0.029 0.031 0.033 0.065 1.155
262,144 numpy 100 0.039 0.006 0.034 0.036 0.037 0.038 0.061 1.000
262,144 mlx 10 0.264 0.004 0.260 0.261 0.263 0.266 0.273 0.147
1,048,576 jax 10 0.047 0.005 0.042 0.044 0.045 0.046 0.059 4.178
1,048,576 numba 10 0.088 0.004 0.084 0.084 0.087 0.091 0.097 2.220
1,048,576 numpy 10 0.195 0.024 0.168 0.175 0.187 0.218 0.234 1.000
1,048,576 pytorch 10 0.205 0.012 0.190 0.194 0.204 0.217 0.219 0.952
1,048,576 mlx 10 1.182 0.015 1.167 1.168 1.177 1.195 1.204 0.165
4,194,304 jax 10 0.249 0.047 0.202 0.227 0.230 0.249 0.358 3.832
4,194,304 numba 10 0.358 0.035 0.324 0.333 0.340 0.393 0.417 2.663
4,194,304 numpy 10 0.954 0.143 0.736 0.851 0.929 1.041 1.249 1.000
4,194,304 pytorch 10 1.010 0.062 0.923 0.964 1.002 1.052 1.124 0.945
4,194,304 mlx 10 4.874 0.138 4.726 4.786 4.858 4.895 5.235 0.196
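For context on what "iterative tridiagonal solves" means here: solving a tridiagonal system sequentially along one axis is usually done with the Thomas algorithm. The sketch below is an illustration under that assumption, not the repo's implementation; the explicit loop over levels is exactly the part that is hard to express without in-place updates and tends to dominate runtime.

```python
import numpy as np

def thomas_solve(a, b, c, d):
    """Illustrative Thomas-algorithm solve (hypothetical helper).

    Solves a tridiagonal system with sub-diagonal `a`, diagonal `b`,
    super-diagonal `c`, and right-hand side `d`, all of length n
    (a[0] and c[-1] are unused).
    """
    n = len(d)
    cp = np.empty(n)
    dp = np.empty(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    # Forward sweep: eliminate the sub-diagonal level by level.
    for i in range(1, n):
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    # Back substitution.
    x = np.empty(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x
```

In the TKE benchmark this kind of sweep is applied once per water column, so the sequential dependency along the vertical axis cannot be vectorized away.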
The GPU is a bit faster than the CPU, for example on isoneutral mixing:
$ python run.py benchmarks/isoneutral_mixing/ --device gpu -b mlx -b numpy
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.013 1.000
4,096 mlx 1,000 0.009 0.002 0.007 0.008 0.008 0.009 0.042 0.202
16,384 numpy 1,000 0.009 0.002 0.006 0.008 0.009 0.010 0.022 1.000
16,384 mlx 1,000 0.009 0.002 0.007 0.008 0.009 0.011 0.045 0.974
65,536 mlx 1,000 0.015 0.004 0.010 0.013 0.015 0.017 0.049 1.454
65,536 numpy 100 0.022 0.003 0.018 0.020 0.021 0.023 0.033 1.000
262,144 mlx 100 0.033 0.009 0.023 0.028 0.030 0.033 0.066 2.430
262,144 numpy 100 0.081 0.013 0.067 0.068 0.077 0.091 0.115 1.000
1,048,576 mlx 100 0.097 0.038 0.074 0.078 0.084 0.090 0.270 4.351
1,048,576 numpy 10 0.421 0.050 0.359 0.376 0.409 0.459 0.501 1.000
4,194,304 mlx 10 0.792 0.135 0.475 0.789 0.819 0.861 0.955 2.337
4,194,304 numpy 10 1.852 0.039 1.777 1.836 1.849 1.879 1.920 1.000
Let me know if you have any thoughts or concerns. I am planning to run these experiments more thoroughly and publish the results within a couple of weeks.