[Performance] PyTorch (MPS) is faster than MLX in the backward pass of convolution layers
Describe the bug
Recently I profiled neural network layer performance in MLX and compared it with PyTorch. Although the MLX forward pass is consistently faster than PyTorch's, on some chips (M1 Pro, M1 Max) PyTorch is much faster (3x~6x) for convolution forward + backward, while on other chips such as the M3 Max, MLX is faster than PyTorch.
To Reproduce
To reproduce this, I have two minimal examples. The networks just contain several convolution layers. You can run the two scripts to verify the performance.
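The scripts themselves aren't inlined in this issue, but the MLX half of such a benchmark can be sketched roughly as follows (the layer sizes, input shape, and iteration counts here are illustrative assumptions, not the original scripts; the PyTorch script is the analogous conv + backward() loop):

```python
import time

import mlx.core as mx
import mlx.nn as nn


class ConvNet(nn.Module):
    """A small stack of convolution layers, as in the benchmark."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Conv2d(32, 32, kernel_size=3, padding=1) for _ in range(4)]

    def __call__(self, x):
        for layer in self.layers:
            x = nn.relu(layer(x))
        return x


model = ConvNet()


def loss_fn(model, x):
    return model(x).sum()


# value_and_grad covers the forward + backward being measured.
loss_and_grad = nn.value_and_grad(model, loss_fn)
x = mx.random.normal((8, 64, 64, 32))  # MLX convolutions take NHWC input

# Warm up before timing; mx.eval forces the lazy graph to run.
for _ in range(3):
    loss, grads = loss_and_grad(model, x)
    mx.eval(loss, grads)

iters = 10
start = time.time()
for _ in range(iters):
    loss, grads = loss_and_grad(model, x)
    mx.eval(loss, grads)
print(f"average time of MLX: {(time.time() - start) / iters * 1000} ms")
```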
Same benchmark on an M2 Ultra
- average time of PyTorch: 7.20261025428772
- average time of MLX: 2.34059739112854
On M2 Pro
- average time of MLX: 4.472527980804443
- average time of PyTorch: 10.073836088180542
On M3 Max
- average time of MLX: 2.4813356399536133
- average time of PyTorch: 7.1081931591033936
Thanks for the benchmarks, everyone! There is clearly an unexpected performance cliff on M1 machines here, as MLX is substantially faster on M2 and later. We'll need to take a deeper look to figure out where it's coming from.
On M1
- average time of MLX: 30.113215446472168
- average time of PyTorch: 15.948616743087769
On M3 Max
- average time of MLX: 2.939736843109131
- average time of PyTorch: 5.9829957485198975
@arnold-yan I took a look at this benchmark.
The performance issue turns out to come from the gradient of the second call to nn.Upsample, which uses nearest-neighbor interpolation by default. The forward pass is a gather under the hood, and the backward pass is a scatter-add. The scatter-add is very inefficient on M1 in this case because it uses atomics and there are many collisions on the same element.
A simple fix is to use linear interpolation (which I believe is what you intended with PyTorch anyway).
Changing to:

```python
upsample = nn.Upsample(scale_factor=(h_scale, w_scale), mode="linear")
```

the benchmark runs in 2.89 ms on my M1 Max, compared to 13.9 ms for PyTorch.
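To make the gather/scatter-add point concrete, here is a tiny 1-D sketch of the same pattern (illustrative only, not the actual MLX upsample kernels):

```python
import mlx.core as mx

# Nearest-neighbor upsampling is a gather: every output element
# reads from a precomputed source index, and with a 2x scale two
# outputs read the same source element.
x = mx.arange(4.0)                        # source row: [0, 1, 2, 3]
idx = mx.array([0, 0, 1, 1, 2, 2, 3, 3])  # 2x nearest-neighbor indices


def upsampled_sum(x):
    return mx.take(x, idx).sum()  # forward: gather, then reduce


# The gradient of a gather is a scatter-add: each cotangent element
# is added back into the source slot it was read from. Duplicated
# indices mean several adds target the same slot, which is what
# forces atomics (and collisions) in the backward kernel.
grad = mx.grad(upsampled_sum)(x)
print(grad)  # [2, 2, 2, 2] -- each source element was read twice
```

With 2-D upsampling the duplication factor is the product of the scale factors, so the collisions grow quickly.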
Hi @awni, thank you for figuring that out! It was indeed a mistake; I intended to use "linear" here. However, when I changed it to "linear" and ran the test again on my M1 Pro MacBook Pro, the running time actually increased:
average time of MLX: 275.73419642448425 ms
I will find an M1 Max machine to verify this again.
@arnold-yan you're right, the benchmark is slower with linear 😓. I made a mistake. Let me keep digging.
Hi @arnold-yan https://github.com/ml-explore/mlx/pull/1541 should improve your benchmark a lot. I ran it on an M1 Max and M3 Max and the numbers are now:
| Machine | MLX (ms) | PyTorch (ms) |
|---|---|---|
| M1 Max | 4.615 | 11.93 |
| M3 Max | 1.938 | 10.77 |
Thanks @awni! Let me test with this PR. The running time on M1 Pro is now:
average time of MLX: 7.930095672607422 ms.
Good job!
By the way, I also tried the linear-interpolation change on top of your PR on M1 Pro. The MLX time changed to:
average time of MLX: 18.393556356430054
average time of PyTorch: 15.363117218017578
It is a little slower than PyTorch. Does that mean there is still room to accelerate MLX in linear mode?