PyTorch (MPS) is faster than MLX for training and inference for ResNets and Transformers (tested on 2 tasks)
Hi!
(I was originally going to file this issue on mlx-examples, but I figured performance is more relevant to this repository.)
I have had the chance to do some comparative benchmarking of PyTorch (MPS) and MLX for training and inference with two different types of models on two different tasks. An identical data pipeline (based on mlx-data) was used for this analysis. Throughput is presented as samples/sec; mean and std dev are reported over 5 epochs.
Device -> MacBook Pro M1, 16 GB
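Throughput here means samples processed per second of wall-clock time, averaged over 5 epochs. A minimal, self-contained sketch of that kind of measurement (the helper and dummy step are illustrative, not the actual benchmarking code in the linked repos):

```python
import time
import numpy as np

def epoch_throughput(step_fn, num_batches, batch_size):
    # Time one "epoch" of step_fn calls and return samples/sec.
    start = time.perf_counter()
    for _ in range(num_batches):
        step_fn()
    elapsed = time.perf_counter() - start
    return num_batches * batch_size / elapsed

# toy stand-in for a real train/eval step
def dummy_step():
    sum(range(10_000))

# mean ± std over 5 epochs, as in the tables below
tp = [epoch_throughput(dummy_step, num_batches=100, batch_size=256) for _ in range(5)]
print(f"{np.mean(tp):.2f}±{np.std(tp):.2f} samples/sec")
```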
1. ResNets on CIFAR-10:
The code for benchmarking can be found here, and is simply an extension of the mlx-examples recipe.
Model | PyTorch Train | PyTorch Inference | MLX Train | MLX Inference |
---|---|---|---|---|
resnet20 | 4452.36±11.43 | 16591.67±222.13 | 416.69±0.40 | 2301.69±5.71 |
resnet44 | 2115.64±5.40 | 7732.72±26.00 | 184.05±0.74 | 1028.34±2.34 |
resnet110 | 839.60±3.06 | 3157.57±9.49 | 69.34±0.18 | 403.51±6.77 |
2. Keyword Spotting Transformer (KWT) on SpeechCommands:
The code for benchmarking can be found here, and is simply an extension of the mlx-examples recipe.
Model | PyTorch Train | PyTorch Inference | MLX Train | MLX Inference |
---|---|---|---|---|
kwt1 | 1485.61±7.42 | 3928.82±40.31 | 667.73±6.49 | 3009.28±82.50 |
kwt2 | 668.70±1.62 | 1881.66±10.29 | 395.56±5.10 | 1495.28±38.46 |
Both links have a COMPARISON.md where these tables can be found, along with some commentary.
Observations
- For ResNet, training on PyTorch MPS is ~10-12x faster than MLX, while inference on PyTorch MPS is ~7-8x faster than MLX.
- For KWT, training on PyTorch MPS is ~2x faster than MLX, while inference on PyTorch MPS is ~1.25x faster than MLX.
While PyTorch was faster in all tests, the gap for ResNet is just too large. Based on Activity Monitor, CPU/GPU utilization was quite similar. I couldn't do a more thorough investigation; for that, a tool like Nvidia Visual Profiler for profiling kernel calls would be needed (I don't know if one exists for Mac).
Thoughts?
Yea these #s don't make sense. It seems likely there is some unexpected bottleneck in the code you tested. Would you be up for trying this example? In previous benchmarks we have been faster than the corresponding torch implementation.
CC @jagrit06
Sure. I'll check that out.
I tested a transformer_lm with the following parameters
--context_size 128 --num_blocks 4 --dim 256 --num_heads 4
because the default is just too large for my MacBook. I didn't want to change much, so I just added tqdm on the for loop in the eval_fn (roughly as sketched below).
And you were right, MLX is faster: ~70 it/s vs ~50 it/s for training, ~212 it/s vs 108 it/s for eval.
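For reference, the change amounts to something like this (illustrative only: the real eval_fn in the transformer_lm example computes a loss over evaluation batches rather than this stand-in):

```python
from tqdm import tqdm

def eval_fn(batches):
    total = 0.0
    for batch in tqdm(batches):   # <- the only change: the tqdm wrapper, to get an it/s readout
        total += sum(batch)       # stand-in for the real per-batch loss computation
    return total / max(len(batches), 1)

print(eval_fn([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
```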
Ok, this is probably good news. It really suggests there is some strange perf bottleneck in the two examples. It could be the convolutions (but that seems unlikely for the KWT example... since there is only one).
Will take a deeper look soon. Thanks for your help with the benchmarking!
Yeah, I just ruled out the convolution in the case of KWT. Since it's a non-overlapping kernel, it can be replaced with just a reshape and a Linear layer, and the numbers remain more or less the same. It's something else.
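For anyone curious, the equivalence is easy to sanity-check with a toy example. Shapes below are made up; the real KWT patch embedding uses its own patch size and embedding dim, and the trained weights obviously differ:

```python
import mlx.core as mx
import mlx.nn as nn

B, H, W, C = 4, 40, 98, 1        # MLX convs are channels-last (NHWC)
ph, pw, D = 40, 1, 64            # patch size and embedding dim (illustrative)

conv = nn.Conv2d(C, D, kernel_size=(ph, pw), stride=(ph, pw))
linear = nn.Linear(ph * pw * C, D)

x = mx.random.normal((B, H, W, C))

# conv path: non-overlapping patch embedding -> (B, H // ph, W // pw, D)
y_conv = conv(x)

# reshape path: cut into non-overlapping patches, flatten, project
patches = x.reshape(B, H // ph, ph, W // pw, pw, C).transpose(0, 1, 3, 2, 4, 5)
patches = patches.reshape(B, H // ph, W // pw, ph * pw * C)
y_lin = linear(patches)

print(y_conv.shape, y_lin.shape)  # same output shape; cost is comparable
```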
> Will take a deeper look soon. Thanks for your help with the benchmarking!
No problemo!
I just ran the examples on an M2 Ultra, and can confirm that convolution isn't the major culprit - as @SarthakYadav said, switching it out for a reshape + linear doesn't change the numbers too much on MLX (though in my case doing that did also slow PyTorch down a bit).
I'll try to dig deeper and see where the performance issues crop up. Thanks again for taking the time to share your results!
@jagrit06 for the speech KWT example, this size of matmul comes up and we are really slow compared to MPS on it (about 3x I think):
compare_filtered("matmul --size 64x25344 --size 25344x64")
@SarthakYadav we also found a pretty severe performance cliff with one of our reduction kernels. I think fixing the matmul and the reduction for those cases should make the MLX version a lot faster... hopefully faster than the other version you tested 🤞
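For anyone wanting to reproduce the matmul comparison outside the benchmark harness, a rough stand-alone timing of the 64x25344 @ 25344x64 shape looks like this (this is not the compare_filtered tooling from the MLX benchmarks, just a sketch; numbers will vary by machine):

```python
import time
import mlx.core as mx
import torch

M, K, N = 64, 25344, 64
ITERS = 100

def time_mlx():
    a = mx.random.normal((M, K))
    b = mx.random.normal((K, N))
    mx.eval(a, b)
    for _ in range(5):                 # warmup
        mx.eval(a @ b)
    tic = time.perf_counter()
    for _ in range(ITERS):
        mx.eval(a @ b)                 # force evaluation each iteration
    return (time.perf_counter() - tic) / ITERS

def time_mps():
    a = torch.randn(M, K, device="mps")
    b = torch.randn(K, N, device="mps")
    for _ in range(5):                 # warmup
        _ = a @ b
    torch.mps.synchronize()
    tic = time.perf_counter()
    for _ in range(ITERS):
        _ = a @ b
    torch.mps.synchronize()            # wait for queued kernels to finish
    return (time.perf_counter() - tic) / ITERS

print(f"mlx: {time_mlx() * 1e3:.3f} ms, mps: {time_mps() * 1e3:.3f} ms")
```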
It's awesome that you found that! 🤞 indeed, I suppose the performance issue with the reduction kernel explains the ResNet slowdown as well?
> I suppose the performance issue with the reduction kernel explains the ResNet slowdown as well?
I think it's part of it. I haven't looked at that benchmark as carefully yet, but we'll test it once these fixes are in and see.
Curious that even for the simple MNIST example, torch shows better performance than MLX. These are the MLX results:
python3 main.py --gpu
Epoch 0: Test accuracy 0.602, Time 2.098 (s)
Epoch 1: Test accuracy 0.846, Time 1.983 (s)
Epoch 2: Test accuracy 0.916, Time 2.155 (s)
...
Epoch 8: Test accuracy 0.957, Time 2.059 (s)
Epoch 9: Test accuracy 0.961, Time 2.022 (s)
and the following are PyTorch:
python3 torch_main.py --gpu
Epoch 0: Test accuracy 0.596, Time 0.820 (s)
Epoch 1: Test accuracy 0.865, Time 0.522 (s)
Epoch 2: Test accuracy 0.904, Time 0.521 (s)
...
Epoch 8: Test accuracy 0.964, Time 0.522 (s)
Epoch 9: Test accuracy 0.941, Time 0.518 (s)
Note that I had another heavy torch training run going in parallel using MPS.
On CPU, both are (obviously) equal.
P.S. For both I used num_layers = 4, hidden_dim = 64; other params are the same as in the repository.
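For clarity, those parameters configure an MLP roughly like the one in the mlx-examples MNIST recipe (sketched from memory; the exact code in the repo may differ slightly):

```python
import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0.0)   # ReLU
        return self.layers[-1](x)

# the settings used for the numbers above
model = MLP(num_layers=4, input_dim=28 * 28, hidden_dim=64, output_dim=10)
print(model(mx.zeros((1, 28 * 28))).shape)  # (1, 10)
```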
The numbers you have for MLX are way too slow; something looks off there. For those parameters on my machine (M1 Max 32GB) MLX is twice as fast:
MLX:
Epoch 3: Test accuracy 0.936, Time 0.219 (s)
Torch:
Epoch 3: Test accuracy 0.933, Time 0.429 (s)
Adding another data point here (M1 16 GB, first gen MB Pro) with mlx 0.0.7, PyTorch 2.1.0 and the benchmarking branch from https://github.com/SarthakYadav/mlx-examples.
MNIST
num_layers = 4, hidden_dim = 64; other parameters as in repo.
MLX:
$ python main.py --gpu
...
Epoch 7: Test accuracy 0.951, Time 0.373 (s)
Epoch 8: Test accuracy 0.953, Time 0.375 (s)
Epoch 9: Test accuracy 0.964, Time 0.370 (s)
Torch:
$ python torch_main.py --gpu
...
Epoch 7: Test accuracy 0.961, Time 0.531 (s)
Epoch 8: Test accuracy 0.962, Time 0.547 (s)
Epoch 9: Test accuracy 0.964, Time 0.545 (s)
CIFAR
All parameters as in repo.
MLX:
$ python main.py
...
Number of params: 0.2697 M
Epoch 00 [000] | Train loss 3.015 | Train acc 0.082 | Throughput: 224.28 images/second
Epoch 00 [050] | Train loss 2.078 | Train acc 0.195 | Throughput: 234.05 images/second
Epoch 00 [100] | Train loss 2.075 | Train acc 0.227 | Throughput: 234.55 images/second
Epoch 00 [150] | Train loss 1.903 | Train acc 0.230 | Throughput: 234.74 images/second
Epoch: 0 | avg. Train loss 2.072 | avg. Train acc 0.215 | Throughput: 234.51 images/sec
Epoch: 0 | Test acc 0.280 | Throughput: 1005.77 images/sec
MLX with nn.BatchNorm (with track_running_stats=False) instead of nn.LayerNorm significantly improves classification performance and, interestingly, also throughput (assuming it is measured correctly). nn.BatchNorm with default params results in a ValueError (reported in https://github.com/ml-explore/mlx-examples/issues/233).
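The swap itself is small. A hypothetical helper to illustrate it (the CIFAR recipe actually instantiates the normalization layer inline in the ResNet block; this function exists only for illustration). The run below uses the BatchNorm variant:

```python
import mlx.nn as nn

def make_norm(dims: int, use_batchnorm: bool = True) -> nn.Module:
    if use_batchnorm:
        # track_running_stats=False sidesteps the ValueError reported in
        # ml-explore/mlx-examples#233 with the default arguments
        return nn.BatchNorm(dims, track_running_stats=False)
    # original recipe: LayerNorm in place of BatchNorm
    return nn.LayerNorm(dims)

norm = make_norm(16)   # drop-in replacement for the per-block normalization
```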
$ python main.py
Number of params: 0.2697 M
Epoch 00 [000] | Train loss 2.657 | Train acc 0.059 | Throughput: 257.40 images/second
Epoch 00 [050] | Train loss 1.672 | Train acc 0.375 | Throughput: 275.69 images/second
Epoch 00 [100] | Train loss 1.466 | Train acc 0.484 | Throughput: 275.45 images/second
Epoch 00 [150] | Train loss 1.406 | Train acc 0.445 | Throughput: 276.19 images/second
Epoch: 0 | avg. Train loss 1.528 | avg. Train acc 0.435 | Throughput: 270.88 images/sec
Epoch: 0 | Test acc 0.525 | Throughput: 1409.26 images/sec
Torch:
$ python main_pytorch.py
Epoch 00 [000] | Train loss 3.679 | Train acc 0.082 | Throughput: 561.25 images/second
Epoch 00 [050] | Train loss 1.710 | Train acc 0.367 | Throughput: 2115.42 images/second
Epoch 00 [100] | Train loss 1.543 | Train acc 0.398 | Throughput: 2144.91 images/second
Epoch 00 [150] | Train loss 1.470 | Train acc 0.449 | Throughput: 2143.19 images/second
Epoch: 0 | avg. Train loss 1.590 | avg. Train acc 0.417 | Throughput: 2117.98 images/sec
Epoch: 0 | Test acc 0.494 | Throughput: 7410.47 images/sec
Did you notice this test: MPS or MLX for Domestic AI? Did MLX use any features that torch.mps doesn't use to boost performance? @awni
@SunnyBeike In my experiments with small (image) CNNs, see above, I am not seeing different GPU frequencies according to asitop. For both torch and mlx, I am getting near 100% GPU @ 1278 MHz on an M1 MacBook Pro.
> @SunnyBeike In my experiments with small (image) CNNs, see above, I am not seeing different GPU frequencies according to asitop. For both torch and mlx, I am getting near 100% GPU @ 1278 MHz on an M1 MacBook Pro.
MLX and PyTorch show different relative performance on CIFAR and MNIST in your tests, which is quite confusing... I found that MLX implements its own kernels and memory management, while torch.mps seems to rely on MPS. In my opinion, MPS should also take advantage of Metal (i.e., unified memory), so I wonder if the performance gap is caused by the backend kernels, and whether unified memory in MLX mainly contributes to running larger models rather than better performance. Am I correct? @menzHSE @awni
Let's close this and open a new issue more specific to optimizing convolutional network training, which is still one of the major slow points in MLX.