
MLX in pyhpc-benchmarks repo

Open: dionhaefner opened this issue

I am the maintainer of pyhpc-benchmarks, a repository that compares the performance of various JIT compilers and ML frameworks on scientific computing workloads (mostly finite difference stencils from an ocean model).

I have recently added MLX as a possible backend. I would like to give you the opportunity to correct inefficiencies or potential mistakes in my implementation before I publish the results. The branch with MLX support is here, benchmarks are in benchmarks/*/*_mlx.py.

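Right now, the performance of MLX on these workloads ranges from decent to poor. On the CPU it beats NumPy on the purely elementwise equation-of-state benchmark, but it is roughly five to eight times slower than NumPy on the slicing-heavy isoneutral and TKE benchmarks. I suspect this is due to the lack of true in-place operations, since the benchmarks make heavy use of slicing and selective updates.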
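To make the in-place hypothesis concrete, here is a plain NumPy sketch, not the MLX benchmark code. `update_inplace` and `update_functional` are hypothetical helpers. The second one only emulates the extra full-array copy that a framework without true in-place updates may have to make for each selective update. Whether MLX actually incurs that copy for a given slice update is an assumption here, not something these benchmarks establish.

```python
import numpy as np


def update_inplace(a, b):
    # NumPy writes straight into the existing buffer of `a`:
    # only the interior slice is touched, no new array is allocated.
    a[1:-1] += 0.5 * (b[2:] - b[:-2])
    return a


def update_functional(a, b):
    # Emulates an immutable-array update: materialize a full copy of `a`,
    # then write the updated interior into that copy. `a` is left unchanged.
    out = a.copy()
    out[1:-1] = a[1:-1] + 0.5 * (b[2:] - b[:-2])
    return out


# Both variants compute the same values; only the memory traffic differs.
x = np.linspace(0.0, 1.0, 8)
y = np.linspace(1.0, 2.0, 8)
assert np.allclose(update_inplace(x.copy(), y), update_functional(x, y))
```

Each functional update touches the whole array even when only a thin slice changes, which would plausibly hurt the stencil benchmarks far more than the elementwise one.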

That said, the ergonomics of implementing the benchmarks in MLX were excellent, especially considering how recently MLX was released: everything worked pretty much out of the box. Great job!

Example output on my M1 MacBook follows.

Equation of state benchmark

(elementwise ops, no slicing)

$ python run.py benchmarks/equation_of_state/ --device cpu

benchmarks.equation_of_state
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numba        100,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004     2.711
       4,096  taichi        10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.002     2.522
       4,096  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003     2.007
       4,096  aesara        10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003     1.718
       4,096  numpy         10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003     1.000
       4,096  tensorflow    10,000     0.001     0.001     0.000     0.000     0.001     0.001     0.008     0.838
       4,096  pytorch       10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003     0.747
       4,096  mlx           10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.004     0.458

      16,384  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.003    18.125
      16,384  taichi        10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003     6.375
      16,384  numba         10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.002     6.117
      16,384  tensorflow    10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.005     4.873
      16,384  aesara        10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.003     4.263
      16,384  mlx           10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.004     2.642
      16,384  numpy         10,000     0.006     0.001     0.003     0.005     0.005     0.006     0.016     1.000
      16,384  pytorch        1,000     0.006     0.002     0.003     0.005     0.006     0.007     0.014     0.956

      65,536  jax           10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.005    15.042
      65,536  taichi         1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.005     4.457
      65,536  tensorflow     1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.005     4.256
      65,536  numba          1,000     0.004     0.000     0.003     0.003     0.004     0.004     0.005     4.094
      65,536  aesara         1,000     0.005     0.000     0.005     0.005     0.005     0.005     0.006     2.975
      65,536  mlx            1,000     0.005     0.000     0.005     0.005     0.005     0.005     0.007     2.850
      65,536  numpy          1,000     0.015     0.002     0.011     0.013     0.014     0.015     0.022     1.000
      65,536  pytorch        1,000     0.022     0.004     0.013     0.020     0.022     0.025     0.041     0.648

     262,144  jax            1,000     0.002     0.000     0.002     0.002     0.002     0.002     0.004    20.878
     262,144  tensorflow     1,000     0.008     0.000     0.008     0.008     0.008     0.008     0.010     5.672
     262,144  taichi         1,000     0.012     0.000     0.011     0.011     0.011     0.012     0.014     4.065
     262,144  numba          1,000     0.013     0.000     0.013     0.013     0.013     0.013     0.014     3.647
     262,144  aesara         1,000     0.017     0.000     0.017     0.017     0.017     0.017     0.019     2.696
     262,144  mlx            1,000     0.022     0.001     0.021     0.022     0.022     0.022     0.028     2.111
     262,144  numpy            100     0.047     0.005     0.041     0.044     0.045     0.047     0.070     1.000
     262,144  pytorch          100     0.057     0.010     0.039     0.051     0.055     0.062     0.084     0.818

   1,048,576  jax            1,000     0.008     0.001     0.007     0.008     0.008     0.009     0.013    26.042
   1,048,576  tensorflow       100     0.033     0.001     0.032     0.032     0.033     0.033     0.036     6.600
   1,048,576  taichi           100     0.048     0.001     0.047     0.048     0.048     0.049     0.051     4.487
   1,048,576  numba            100     0.053     0.001     0.052     0.052     0.053     0.054     0.057     4.074
   1,048,576  aesara           100     0.071     0.001     0.070     0.071     0.071     0.071     0.074     3.044
   1,048,576  mlx              100     0.102     0.004     0.099     0.100     0.100     0.101     0.127     2.131
   1,048,576  pytorch          100     0.184     0.009     0.167     0.178     0.182     0.189     0.209     1.178
   1,048,576  numpy            100     0.216     0.011     0.194     0.208     0.214     0.223     0.252     1.000

   4,194,304  jax              100     0.035     0.002     0.032     0.034     0.035     0.036     0.040    37.595
   4,194,304  tensorflow       100     0.131     0.004     0.126     0.128     0.130     0.133     0.141    10.101
   4,194,304  taichi           100     0.191     0.003     0.187     0.190     0.191     0.192     0.197     6.914
   4,194,304  numba            100     0.211     0.003     0.207     0.208     0.210     0.213     0.219     6.270
   4,194,304  aesara           100     0.280     0.003     0.273     0.279     0.280     0.282     0.289     4.713
   4,194,304  mlx               10     0.445     0.028     0.403     0.425     0.450     0.462     0.500     2.967
   4,194,304  pytorch           10     0.815     0.066     0.753     0.768     0.775     0.850     0.958     1.620
   4,194,304  numpy             10     1.321     0.350     0.952     1.012     1.170     1.694     1.874     1.000

(time in wall seconds, lower is better; Δ is the speedup relative to NumPy, higher is better)
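The Δ column appears to be the ratio of NumPy's mean runtime to each backend's mean runtime, so values above 1 mean faster than NumPy. A minimal sketch assuming that definition; `speedup_vs_numpy` is a hypothetical helper, not part of pyhpc-benchmarks:

```python
def speedup_vs_numpy(means):
    """Return each backend's speedup relative to NumPy.

    `means` maps backend name -> mean wall time in seconds.
    Assumes delta = mean(numpy) / mean(backend).
    """
    reference = means["numpy"]
    return {name: reference / t for name, t in means.items()}


# Rounded mean times from the 4,194,304-element equation-of-state run.
means = {"jax": 0.035, "mlx": 0.445, "numpy": 1.321}
for name, delta in sorted(speedup_vs_numpy(means).items(), key=lambda kv: -kv[1]):
    print(f"{name:8s} {delta:7.3f}")
```

With the rounded means this gives about 37.7 for JAX and 2.97 for MLX, close to the reported 37.595 and 2.967; the small gap is consistent with the means being rounded to three decimals.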

Isoneutral diffusion benchmark

(4D slicing and index shifts)
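As an illustration of the "slicing and index shifts" pattern, here is a NumPy sketch of a forward finite difference along the first axis of a 4D array. The function name, the axis layout (first axis = x) and the zero boundary value are all assumptions for illustration; the actual benchmark kernels are more involved.

```python
import numpy as np


def forward_diff_x(field, dx):
    """Forward difference along the first axis of a 4D array.

    Interior points get (f[i+1] - f[i]) / dx via shifted slices; the last
    plane along x has no forward neighbour and is left at zero.
    """
    out = np.zeros_like(field)
    # Index shift: field[1:] pairs each point with its neighbour at i+1,
    # and the result is written into a sub-slice of `out` (selective update).
    out[:-1, ...] = (field[1:, ...] - field[:-1, ...]) / dx
    return out
```

Kernels like this combine many shifted reads with partial writes, which is exactly where the cost of each slice update matters.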

benchmarks.isoneutral_mixing
============================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.001     0.001     0.000     0.000     0.001     0.001     0.010     2.456
       4,096  numba         10,000     0.001     0.001     0.000     0.000     0.000     0.001     0.011     2.381
       4,096  taichi        10,000     0.001     0.001     0.001     0.001     0.001     0.001     0.012     2.240
       4,096  aesara        10,000     0.001     0.001     0.001     0.001     0.001     0.001     0.015     1.222
       4,096  numpy         10,000     0.002     0.001     0.002     0.002     0.002     0.002     0.014     1.000
       4,096  pytorch       10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.012     0.765
       4,096  mlx            1,000     0.010     0.001     0.009     0.009     0.009     0.010     0.023     0.186

      16,384  numba          1,000     0.003     0.001     0.002     0.002     0.002     0.003     0.006     3.558
      16,384  taichi         1,000     0.003     0.000     0.003     0.003     0.003     0.003     0.006     3.305
      16,384  jax            1,000     0.003     0.001     0.002     0.003     0.003     0.004     0.007     2.810
      16,384  aesara         1,000     0.006     0.001     0.005     0.006     0.006     0.006     0.014     1.495
      16,384  numpy          1,000     0.009     0.002     0.005     0.008     0.009     0.010     0.017     1.000
      16,384  pytorch        1,000     0.011     0.002     0.005     0.009     0.011     0.012     0.019     0.848
      16,384  mlx              100     0.036     0.001     0.034     0.035     0.035     0.036     0.048     0.254

      65,536  numba          1,000     0.010     0.001     0.009     0.010     0.010     0.011     0.014     2.199
      65,536  jax            1,000     0.011     0.001     0.009     0.010     0.011     0.011     0.015     2.097
      65,536  taichi           100     0.011     0.001     0.011     0.011     0.011     0.011     0.015     2.034
      65,536  numpy          1,000     0.023     0.003     0.018     0.021     0.022     0.023     0.038     1.000
      65,536  aesara           100     0.024     0.001     0.023     0.024     0.024     0.025     0.028     0.930
      65,536  pytorch          100     0.028     0.004     0.019     0.025     0.027     0.030     0.039     0.814
      65,536  mlx              100     0.139     0.002     0.136     0.137     0.138     0.139     0.145     0.163

     262,144  jax              100     0.023     0.002     0.021     0.022     0.023     0.024     0.028     3.212
     262,144  numba            100     0.039     0.001     0.037     0.038     0.038     0.039     0.043     1.953
     262,144  taichi           100     0.046     0.001     0.045     0.045     0.046     0.047     0.050     1.626
     262,144  numpy            100     0.075     0.010     0.066     0.069     0.070     0.078     0.119     1.000
     262,144  pytorch          100     0.078     0.018     0.063     0.068     0.072     0.079     0.148     0.967
     262,144  aesara           100     0.092     0.004     0.086     0.089     0.091     0.094     0.110     0.821
     262,144  mlx               10     0.515     0.005     0.505     0.512     0.514     0.520     0.523     0.146

   1,048,576  jax               10     0.108     0.005     0.098     0.105     0.108     0.111     0.118     3.306
   1,048,576  numba             10     0.163     0.007     0.153     0.155     0.165     0.169     0.172     2.198
   1,048,576  taichi            10     0.186     0.005     0.179     0.182     0.187     0.190     0.193     1.921
   1,048,576  numpy             10     0.357     0.013     0.344     0.349     0.356     0.361     0.389     1.000
   1,048,576  aesara            10     0.430     0.032     0.402     0.414     0.421     0.428     0.521     0.831
   1,048,576  pytorch           10     0.473     0.026     0.401     0.474     0.477     0.484     0.500     0.756
   1,048,576  mlx               10     2.177     0.023     2.137     2.161     2.173     2.195     2.214     0.164

   4,194,304  jax               10     0.569     0.054     0.474     0.538     0.565     0.617     0.652     3.016
   4,194,304  numba             10     0.669     0.013     0.651     0.657     0.669     0.681     0.688     2.566
   4,194,304  taichi            10     0.726     0.020     0.709     0.711     0.719     0.725     0.766     2.366
   4,194,304  numpy             10     1.718     0.137     1.526     1.586     1.740     1.788     1.931     1.000
   4,194,304  aesara            10     1.896     0.117     1.695     1.810     1.922     1.995     2.042     0.906
   4,194,304  pytorch           10     1.937     0.105     1.826     1.842     1.935     1.981     2.186     0.887
   4,194,304  mlx               10     8.707     0.077     8.633     8.664     8.681     8.718     8.911     0.197

TKE benchmark

(3D slicing and iterative tridiagonal solves)
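A common way to solve many tridiagonal systems on arrays is the Thomas algorithm, which sweeps sequentially along the vertical axis while vectorizing over all horizontal columns. Whether the TKE benchmark uses exactly this formulation is an assumption; `solve_tridiag` and the layout (last axis = vertical) are hypothetical. The sketch is in NumPy:

```python
import numpy as np


def solve_tridiag(a, b, c, d):
    """Solve many tridiagonal systems at once with the Thomas algorithm.

    All inputs have shape (..., n); the last axis is the vertical.
    a: sub-diagonal (a[..., 0] unused), b: main diagonal,
    c: super-diagonal (c[..., -1] unused), d: right-hand side.
    No pivoting, so the systems should be diagonally dominant.
    """
    n = b.shape[-1]
    cp = np.empty_like(b)
    dp = np.empty_like(b)
    # Forward sweep: eliminate the sub-diagonal one level at a time.
    cp[..., 0] = c[..., 0] / b[..., 0]
    dp[..., 0] = d[..., 0] / b[..., 0]
    for k in range(1, n):
        denom = b[..., k] - a[..., k] * cp[..., k - 1]
        cp[..., k] = c[..., k] / denom
        dp[..., k] = (d[..., k] - a[..., k] * dp[..., k - 1]) / denom
    # Back substitution, again sequential along the vertical axis.
    x = np.empty_like(b)
    x[..., -1] = dp[..., -1]
    for k in range(n - 2, -1, -1):
        x[..., k] = dp[..., k] - cp[..., k] * x[..., k + 1]
    return x
```

Every loop iteration writes a single slice, so a backend that pays extra for each slice write plausibly pays it once per vertical level here.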

$ python run.py benchmarks/turbulent_kinetic_energy/ --device cpu

benchmarks.turbulent_kinetic_energy
===================================
Running on CPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  jax           10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004     4.930
       4,096  numba         10,000     0.000     0.000     0.000     0.000     0.000     0.000     0.004     1.917
       4,096  numpy         10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.005     1.000
       4,096  pytorch       10,000     0.001     0.000     0.001     0.001     0.001     0.001     0.007     0.697
       4,096  mlx            1,000     0.007     0.000     0.006     0.007     0.007     0.007     0.011     0.121

      16,384  jax            1,000     0.001     0.000     0.001     0.001     0.001     0.001     0.005     5.496
      16,384  numba          1,000     0.002     0.000     0.001     0.001     0.001     0.002     0.003     2.128
      16,384  pytorch        1,000     0.003     0.001     0.002     0.003     0.003     0.004     0.008     1.035
      16,384  numpy          1,000     0.003     0.001     0.002     0.003     0.003     0.004     0.021     1.000
      16,384  mlx            1,000     0.022     0.001     0.020     0.021     0.022     0.022     0.033     0.160

      65,536  jax            1,000     0.003     0.001     0.002     0.003     0.003     0.003     0.023     4.253
      65,536  numba          1,000     0.006     0.001     0.005     0.005     0.006     0.006     0.009     2.011
      65,536  pytorch          100     0.011     0.002     0.007     0.010     0.011     0.012     0.016     1.104
      65,536  numpy            100     0.012     0.002     0.009     0.011     0.011     0.012     0.024     1.000
      65,536  mlx              100     0.080     0.002     0.077     0.079     0.079     0.080     0.087     0.149

     262,144  jax              100     0.009     0.001     0.007     0.008     0.009     0.009     0.014     4.411
     262,144  numba            100     0.021     0.002     0.019     0.020     0.021     0.022     0.031     1.832
     262,144  pytorch          100     0.034     0.009     0.026     0.029     0.031     0.033     0.065     1.155
     262,144  numpy            100     0.039     0.006     0.034     0.036     0.037     0.038     0.061     1.000
     262,144  mlx               10     0.264     0.004     0.260     0.261     0.263     0.266     0.273     0.147

   1,048,576  jax               10     0.047     0.005     0.042     0.044     0.045     0.046     0.059     4.178
   1,048,576  numba             10     0.088     0.004     0.084     0.084     0.087     0.091     0.097     2.220
   1,048,576  numpy             10     0.195     0.024     0.168     0.175     0.187     0.218     0.234     1.000
   1,048,576  pytorch           10     0.205     0.012     0.190     0.194     0.204     0.217     0.219     0.952
   1,048,576  mlx               10     1.182     0.015     1.167     1.168     1.177     1.195     1.204     0.165

   4,194,304  jax               10     0.249     0.047     0.202     0.227     0.230     0.249     0.358     3.832
   4,194,304  numba             10     0.358     0.035     0.324     0.333     0.340     0.393     0.417     2.663
   4,194,304  numpy             10     0.954     0.143     0.736     0.851     0.929     1.041     1.249     1.000
   4,194,304  pytorch           10     1.010     0.062     0.923     0.964     1.002     1.052     1.124     0.945
   4,194,304  mlx               10     4.874     0.138     4.726     4.786     4.858     4.895     5.235     0.196

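MLX is considerably faster on the GPU. On the isoneutral benchmark, for example, it overtakes NumPy (which always runs on the CPU) from 65,536 elements upward: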

$ python run.py benchmarks/isoneutral_mixing/ --device gpu -b mlx -b numpy

benchmarks.isoneutral_mixing
============================
Running on GPU

size          backend     calls     mean      stdev     min       25%       median    75%       max       Δ
------------------------------------------------------------------------------------------------------------------
       4,096  numpy         10,000     0.002     0.000     0.002     0.002     0.002     0.002     0.013     1.000
       4,096  mlx            1,000     0.009     0.002     0.007     0.008     0.008     0.009     0.042     0.202

      16,384  numpy          1,000     0.009     0.002     0.006     0.008     0.009     0.010     0.022     1.000
      16,384  mlx            1,000     0.009     0.002     0.007     0.008     0.009     0.011     0.045     0.974

      65,536  mlx            1,000     0.015     0.004     0.010     0.013     0.015     0.017     0.049     1.454
      65,536  numpy            100     0.022     0.003     0.018     0.020     0.021     0.023     0.033     1.000

     262,144  mlx              100     0.033     0.009     0.023     0.028     0.030     0.033     0.066     2.430
     262,144  numpy            100     0.081     0.013     0.067     0.068     0.077     0.091     0.115     1.000

   1,048,576  mlx              100     0.097     0.038     0.074     0.078     0.084     0.090     0.270     4.351
   1,048,576  numpy             10     0.421     0.050     0.359     0.376     0.409     0.459     0.501     1.000

   4,194,304  mlx               10     0.792     0.135     0.475     0.789     0.819     0.861     0.955     2.337
   4,194,304  numpy             10     1.852     0.039     1.777     1.836     1.849     1.879     1.920     1.000

Let me know if you have any thoughts or concerns. I was planning to run these experiments more thoroughly and publish the results within a couple of weeks.

dionhaefner, Jan 04 '24 15:01