
PyTorch (MPS) is faster than MLX for training and inference for ResNets and Transformers (tested on 2 tasks)


Hi!

(I was originally going to file this issue on mlx-examples, but I figured performance is more relevant to this repository.)

I have had the chance to do some comparative benchmarking of PyTorch (MPS) and MLX for training and inference, with two different types of models on two different tasks. An identical data pipeline (based on mlx-data) was used for both frameworks. Throughput is presented as samples/sec; mean and std dev are reported over 5 epochs.

Device -> Macbook Pro M1 16 GB
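
For context, a minimal sketch of how throughput like this could be measured (not the actual code in the linked repos; `step_fn`, `train_step`, and `loader` are placeholders):

```python
# Samples/sec per epoch, then mean ± std over 5 epochs (illustrative only).
import time
import numpy as np

def epoch_throughput(step_fn, batches, batch_size):
    tic = time.perf_counter()
    n_samples = 0
    for batch in batches:
        step_fn(batch)          # one training or inference step
        n_samples += batch_size
    return n_samples / (time.perf_counter() - tic)

# rates = [epoch_throughput(train_step, loader, batch_size=256) for _ in range(5)]
# print(f"{np.mean(rates):.2f}±{np.std(rates):.2f} samples/sec")
```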

1. ResNets on CIFAR-10:

The code for benchmarking can be found here, and is simply an extension of the mlx-examples recipe.

| Model | PyTorch Train | PyTorch Inference | MLX Train | MLX Inference |
|---|---|---|---|---|
| resnet20 | 4452.36±11.43 | 16591.67±222.13 | 416.69±0.40 | 2301.69±5.71 |
| resnet44 | 2115.64±5.40 | 7732.72±26.00 | 184.05±0.74 | 1028.34±2.34 |
| resnet110 | 839.60±3.06 | 3157.57±9.49 | 69.34±0.18 | 403.51±6.77 |

2. Keyword Spotting Transformer (KWT) on SpeechCommands:

The code for benchmarking can be found here, and is simply an extension of the mlx-examples recipe.

| Model | PyTorch Train | PyTorch Inference | MLX Train | MLX Inference |
|---|---|---|---|---|
| kwt1 | 1485.61±7.42 | 3928.82±40.31 | 667.73±6.49 | 3009.28±82.50 |
| kwt2 | 668.70±1.62 | 1881.66±10.29 | 395.56±5.10 | 1495.28±38.46 |

Both links have a COMPARISON.md where these tables can be found, along with some commentary.

Observations

  • For ResNet, training on PyTorch MPS is ~10-12x faster than MLX, while inference on PyTorch MPS is ~7-8x faster than MLX.
  • For KWT, training on PyTorch MPS is ~2x faster than MLX, while inference on PyTorch MPS is ~1.25-1.3x faster than MLX.

While PyTorch was faster in all tests, the gap for ResNet is just too large. Based on Activity Monitor, CPU/GPU utilization was quite similar in both cases. I couldn't do a more thorough investigation; for that, something like the Nvidia Visual Profiler for profiling kernel calls would be needed (I don't know if such a tool exists for Mac).

Thoughts?

SarthakYadav avatar Dec 21 '23 18:12 SarthakYadav

Yeah, these numbers don't make sense. It seems likely there is some unexpected bottleneck in the code you tested. Would you be up for trying this example? In previous benchmarks we have been faster than the corresponding torch implementation.

CC @jagrit06

awni avatar Dec 21 '23 18:12 awni

Sure. I'll check that out.

SarthakYadav avatar Dec 21 '23 18:12 SarthakYadav

I tested the transformer_lm example with the following parameters:

`--context_size 128 --num_blocks 4 --dim 256 --num_heads 4`

because the default is just too large for my MacBook. I didn't want to change much, so I just added tqdm around the for loop in the eval_fn.

And you were right, MLX is faster: ~70 it/s vs ~50 it/s for training, ~212 it/s vs 108 it/s for eval.

SarthakYadav avatar Dec 21 '23 18:12 SarthakYadav

Ok, this is probably good news. It really suggests there is some strange perf bottleneck in the two examples. It could be the convolutions (but that seems unlikely for the KWT example... since there is only one).

Will take a deeper look soon. Thanks for your help with the benchmarking!

awni avatar Dec 21 '23 18:12 awni

Yeah, I just ruled out the convolution in the case of KWT. Since it's a non-overlapping kernel, it can be replaced with just a reshape and a Linear layer, and the numbers remain more or less the same. It's something else.
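
For illustration, a minimal sketch of why that swap is possible (hypothetical shapes, not the actual KWT code): a convolution whose stride equals its kernel size just flattens each patch and projects it, which is exactly a reshape followed by a Linear.

```python
# Non-overlapping Conv2d vs. reshape + Linear patch embedding (MLX uses NHWC).
import mlx.core as mx
import mlx.nn as nn

B, H, W, C = 4, 80, 100, 1     # e.g. a batch of spectrograms (placeholder shapes)
ph, pw, dim = 40, 1, 64        # non-overlapping patch size and embedding dim

x = mx.random.normal((B, H, W, C))

# Conv-based patch embedding: kernel_size == stride, so patches don't overlap.
conv = nn.Conv2d(C, dim, kernel_size=(ph, pw), stride=(ph, pw))

# Equivalent reshape + Linear: cut into patches, flatten each patch, project.
proj = nn.Linear(ph * pw * C, dim)
patches = x.reshape(B, H // ph, ph, W // pw, pw, C).transpose(0, 1, 3, 2, 4, 5)
patches = patches.reshape(B, (H // ph) * (W // pw), ph * pw * C)
out = proj(patches)  # (B, num_patches, dim); with matching weights this equals the conv output
```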

> Will take a deeper look soon. Thanks for your help with the benchmarking!

No problemo!

SarthakYadav avatar Dec 21 '23 19:12 SarthakYadav

I just ran the examples on an M2 Ultra and can confirm that convolution isn't the major culprit - as @SarthakYadav said, switching it out for a reshape + linear doesn't change the numbers too much on MLX (though in my case doing that did also slow PyTorch down a bit).

I'll try to dig deeper and see where the performance issues crop up. Thanks again for taking the time to share your results!

jagrit06 avatar Dec 21 '23 21:12 jagrit06

@jagrit06 for the speech KWT example, this size matmul comes up and we are really slow compared to MPS on it (about 3x I think):

compare_filtered("matmul --size 64x25344 --size 25344x64")

awni avatar Dec 23 '23 20:12 awni

@SarthakYadav we also found a pretty severe performance cliff with one of our reduction kernels. I think fixing the matmul and the reduction for those cases should make the MLX version a lot faster... hopefully faster than the other version you tested 🤞

awni avatar Dec 23 '23 21:12 awni

It's awesome that you found that! 🤞 indeed. I suppose the performance issue with the reduction kernel explains the ResNet slowdown as well?

SarthakYadav avatar Dec 23 '23 22:12 SarthakYadav

> I suppose the performance issue with the reduction kernel explains the ResNet slowdown as well?

I think it's part of it. I haven't looked at that benchmark as carefully yet, but we'll test it once these fixes are in and see.

awni avatar Dec 24 '23 00:12 awni

Curious that even for the simple MNIST example, torch showed better performance than MLX. These are the MLX results:

python3 main.py --gpu
Epoch 0: Test accuracy 0.602, Time 2.098 (s)
Epoch 1: Test accuracy 0.846, Time 1.983 (s)
Epoch 2: Test accuracy 0.916, Time 2.155 (s)
...
Epoch 8: Test accuracy 0.957, Time 2.059 (s)
Epoch 9: Test accuracy 0.961, Time 2.022 (s)

and the following are PyTorch:

python3 torch_main.py --gpu
Epoch 0: Test accuracy 0.596, Time 0.820 (s)
Epoch 1: Test accuracy 0.865, Time 0.522 (s)
Epoch 2: Test accuracy 0.904, Time 0.521 (s)
...
Epoch 8: Test accuracy 0.964, Time 0.522 (s)
Epoch 9: Test accuracy 0.941, Time 0.518 (s)

Note that I had another heavy torch training running in parallel using MPS.

In the CPU case both are (obviously) equal.

P.S. For both I used num_layers = 4, hidden_dim = 64; other params are the same as in the repository.
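
For reference, a sketch of that MLP configuration (the actual model is defined in mlx-examples/mnist/main.py and may differ in details):

```python
# An MLP with num_layers=4 hidden layers of width hidden_dim=64 (illustrative).
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, num_layers=4, input_dim=28 * 28, hidden_dim=64, num_classes=10):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * num_layers + [num_classes]
        self.layers = [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = nn.relu(layer(x))
        return self.layers[-1](x)
```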

anfedoro avatar Jan 01 '24 20:01 anfedoro

The numbers you have for MLX are way too slow; something looks off there. For those parameters on my machine (M1 Max 32GB), MLX is twice as fast:

MLX:

Epoch 3: Test accuracy 0.936, Time 0.219 (s)

Torch:

Epoch 3: Test accuracy 0.933, Time 0.429 (s)

awni avatar Jan 05 '24 04:01 awni

Adding another data point here (M1 16 GB, first gen MB Pro) with mlx 0.0.7, PyTorch 2.1.0 and the benchmarking branch from https://github.com/SarthakYadav/mlx-examples.

MNIST

num_layers = 4, hidden_dim = 64; other parameters as in the repo.

MLX:

$ python main.py --gpu
...
Epoch 7: Test accuracy 0.951, Time 0.373 (s)
Epoch 8: Test accuracy 0.953, Time 0.375 (s)
Epoch 9: Test accuracy 0.964, Time 0.370 (s)

Torch:

$ python torch_main.py  --gpu
...
Epoch 7: Test accuracy 0.961, Time 0.531 (s)
Epoch 8: Test accuracy 0.962, Time 0.547 (s)
Epoch 9: Test accuracy 0.964, Time 0.545 (s)

CIFAR

All parameters as in repo.

MLX:

$ python main.py 
...
Number of params: 0.2697 M
Epoch 00 [000] | Train loss 3.015 | Train acc 0.082 | Throughput: 224.28 images/second
Epoch 00 [050] | Train loss 2.078 | Train acc 0.195 | Throughput: 234.05 images/second
Epoch 00 [100] | Train loss 2.075 | Train acc 0.227 | Throughput: 234.55 images/second
Epoch 00 [150] | Train loss 1.903 | Train acc 0.230 | Throughput: 234.74 images/second
Epoch: 0 | avg. Train loss 2.072 | avg. Train acc 0.215 | Throughput: 234.51 images/sec
Epoch: 0 | Test acc 0.280 | Throughput: 1005.77 images/sec

MLX with nn.BatchNorm (with track_running_stats=False) instead of nn.LayerNorm significantly improves classification performance and, interestingly, also throughput (assuming it is measured correctly); a sketch of this substitution follows the results below. nn.BatchNorm with default params results in a ValueError (reported in https://github.com/ml-explore/mlx-examples/issues/233).

$ python main.py 
Number of params: 0.2697 M
Epoch 00 [000] | Train loss 2.657 | Train acc 0.059 | Throughput: 257.40 images/second
Epoch 00 [050] | Train loss 1.672 | Train acc 0.375 | Throughput: 275.69 images/second
Epoch 00 [100] | Train loss 1.466 | Train acc 0.484 | Throughput: 275.45 images/second
Epoch 00 [150] | Train loss 1.406 | Train acc 0.445 | Throughput: 276.19 images/second
Epoch: 0 | avg. Train loss 1.528 | avg. Train acc 0.435 | Throughput: 270.88 images/sec
Epoch: 0 | Test acc 0.525 | Throughput: 1409.26 images/sec

Torch:

$ python main_pytorch.py 
Epoch 00 [000] | Train loss 3.679 | Train acc 0.082 | Throughput: 561.25 images/second
Epoch 00 [050] | Train loss 1.710 | Train acc 0.367 | Throughput: 2115.42 images/second
Epoch 00 [100] | Train loss 1.543 | Train acc 0.398 | Throughput: 2144.91 images/second
Epoch 00 [150] | Train loss 1.470 | Train acc 0.449 | Throughput: 2143.19 images/second
Epoch: 0 | avg. Train loss 1.590 | avg. Train acc 0.417 | Throughput: 2117.98 images/sec
Epoch: 0 | Test acc 0.494 | Throughput: 7410.47 images/sec
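
For reference, a minimal sketch of the nn.LayerNorm → nn.BatchNorm substitution described above (with a placeholder channel count `dims`; the actual block definition in the benchmarking branch may differ):

```python
# Normalization swap used in the results above (illustrative only).
import mlx.nn as nn

dims = 16  # placeholder: channel count of the block

# As in the benchmarked recipe:
norm = nn.LayerNorm(dims)

# Replacement that improved both accuracy and throughput in the results above.
# track_running_stats=False sidesteps the ValueError reported in
# https://github.com/ml-explore/mlx-examples/issues/233 with the defaults.
norm = nn.BatchNorm(dims, track_running_stats=False)
```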

menzHSE avatar Jan 05 '24 10:01 menzHSE

Did you notice this test ("MPS or MLX for Domestic AI?")? Did MLX use any features that torch.mps doesn't use to boost the performance? @awni

SunnyBeike avatar Jan 18 '24 12:01 SunnyBeike

@SunnyBeike In my experiments with small (image) CNNs (see above), I am not seeing different GPU frequencies according to asitop. For both torch and mlx, I am getting near 100% GPU @ 1278 MHz on an M1 MacBook Pro.

menzHSE avatar Jan 18 '24 12:01 menzHSE

> @SunnyBeike In my experiments with small (image) CNNs (see above), I am not seeing different GPU frequencies according to asitop. For both torch and mlx, I am getting near 100% GPU @ 1278 MHz on an M1 MacBook Pro.

MLX and PyTorch show different relative performance on CIFAR and MNIST in your tests, which is quite confusing... I found that MLX implements its own kernels and memory management, while torch.mps seems to rely on MPS. In my opinion, MPS should also benefit from Metal's advantages (i.e., unified memory), so I wonder whether the performance gap is caused by the backend kernels, and whether unified memory in MLX mainly helps with running larger models rather than improving performance. Am I correct? @menzHSE @awni

SunnyBeike avatar Jan 19 '24 03:01 SunnyBeike

Let's close this and open a new issue more specific to optimizing convolutional network training, which is still one of the major slow points in MLX.

awni avatar Aug 08 '24 15:08 awni