[Performance] PyTorch (MPS) is faster than MLX in backward of convolution layer

Open • arnold-yan opened this issue 1 year ago • 6 comments

Describe the bug
Recently I profiled the performance of neural network layers in MLX and compared it with PyTorch. MLX is consistently faster for the forward pass, but on some chips (M1 Pro, M1 Max) PyTorch is much faster (3x~6x) for convolution forward + backward. On other chips, such as the M3 Max, MLX is faster than PyTorch.

To Reproduce
To reproduce this, I prepared two minimal examples. Each network just has several convolution layers. You can run these two scripts to verify the performance.

time_pytorch_mlx.zip
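For anyone re-running these numbers without the attached zip, a timing helper along these lines may be useful. It is a hypothetical sketch, not the code in `time_pytorch_mlx.zip`, and `average_time_ms` is an illustrative name. Both frameworks dispatch GPU work asynchronously (and MLX also evaluates lazily), so the function passed in should force the work to finish, for example by calling `mx.eval(...)` on the loss and gradients in MLX, or `torch.mps.synchronize()` after `backward()` in PyTorch; otherwise only launch overhead is measured.

```python
import time
from typing import Callable


def average_time_ms(step: Callable[[], object], warmup: int = 5, iters: int = 20) -> float:
    """Return the mean wall-clock time of `step()` in milliseconds.

    `step` must block until the work it launches has finished; otherwise only
    the (asynchronous) launch overhead is measured.
    """
    for _ in range(warmup):
        step()  # untimed warm-up runs (e.g. kernel compilation, allocator growth)
    start = time.perf_counter()
    for _ in range(iters):
        step()
    return (time.perf_counter() - start) / iters * 1000.0


if __name__ == "__main__":
    # Toy CPU-bound step standing in for one forward + backward pass.
    print(f"average time: {average_time_ms(lambda: sum(range(100_000))):.3f} ms")
```

The warm-up runs are excluded from the average so that one-time costs do not skew the per-iteration number.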

arnold-yan • Aug 08 '24

Same benchmark on an M2 Ultra

average time of Pytorch: 7.20261025428772
average time of MLX: 2.34059739112854

awni • Aug 08 '24

On M2 Pro

  • average time of MLX: 4.472527980804443
  • average time of Pytorch: 10.073836088180542

aturker1 • Aug 12 '24

On M3 Max

  • average time of MLX: 2.4813356399536133
  • average time of PyTorch: 7.1081931591033936

alwint3r • Aug 13 '24

Thanks for the benchmarks everyone! There is clearly an unexpected performance cliff on M1 machines here as MLX is substantially faster on M2+. We'll need to take a deeper look at that to figure out where it's coming from.

awni • Aug 13 '24

On M1

  • average time of MLX: 30.113215446472168
  • average time of Pytorch: 15.948616743087769

pyvadev • Sep 13 '24

On M3 Max

  • average time of MLX: 2.939736843109131
  • average time of Pytorch: 5.9829957485198975

jrp2014 • Sep 13 '24

@arnold-yan, I took a look at this benchmark.

The performance issue turns out to come from the gradient of the second call to nn.Upsample, which uses nearest neighbor interpolation by default. Under the hood, the forward is a gather and the backward is a scatter add. The scatter add is very inefficient on M1 in this case because it uses atomics and many updates collide on the same element.
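The gather/scatter-add relationship described above can be sketched in pure Python for a single-channel 2D grid, assuming integer scale factors `sh` and `sw` (all function names here are illustrative, not MLX APIs). With 2x upsampling in both directions, four output gradients land on every input element, which is the kind of write collision that forces atomics in a parallel kernel. The last function computes the same gradient as a gather over each input's own output block, one collision-free formulation; whether MLX's kernel, or the fix discussed later in this thread, takes this route is not stated here.

```python
from typing import List

Grid = List[List[float]]


def upsample_nearest(x: Grid, sh: int, sw: int) -> Grid:
    """Forward pass: output pixel (i, j) gathers input pixel (i // sh, j // sw)."""
    h, w = len(x), len(x[0])
    return [[x[i // sh][j // sw] for j in range(w * sw)] for i in range(h * sh)]


def grad_scatter_add(g: Grid, sh: int, sw: int) -> Grid:
    """Backward as a scatter-add: each output gradient is added into its source.

    sh * sw different outputs target the same input slot, which is where a
    parallel GPU implementation needs atomic adds and sees write collisions.
    """
    h, w = len(g) // sh, len(g[0]) // sw
    dx = [[0.0] * w for _ in range(h)]
    for i in range(h * sh):
        for j in range(w * sw):
            dx[i // sh][j // sw] += g[i][j]
    return dx


def grad_gather(g: Grid, sh: int, sw: int) -> Grid:
    """Same gradient as a gather: each input sums its own sh x sw output block,
    so every result element is written exactly once (no atomics needed)."""
    h, w = len(g) // sh, len(g[0]) // sw
    return [
        [sum(g[i * sh + a][j * sw + b] for a in range(sh) for b in range(sw)) for j in range(w)]
        for i in range(h)
    ]
```

Both backward functions return identical results; they differ only in whether the parallel work is organized around outputs (scatter) or inputs (gather).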

A simple fix is to use linear interpolation (which I believe is what you use in PyTorch anyway).

Changing to:

        upsample = nn.Upsample(scale_factor=(h_scale, w_scale), mode="linear")

The benchmark runs in 2.89 ms on my M1 Max compared to 13.9 ms for PyTorch.
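To see why switching the mode changes the backward kernel, here is a 1D sketch of linear interpolation and its gradient. It assumes the half-pixel index mapping (as in PyTorch's `align_corners=False`); the exact mapping and edge clamping used by MLX's `nn.Upsample` are not spelled out in this thread, so treat those details as an assumption, and the function names as hypothetical. The backward is still a scatter-add, but each output spreads fractional weights over two neighboring inputs instead of one; a 2D (bilinear) version applies the same taps separately along each axis.

```python
import math
from typing import List, Tuple


def linear_taps(n_in: int, n_out: int) -> List[Tuple[int, int, float]]:
    """For each output index return (i0, i1, t) so y[o] = (1 - t) * x[i0] + t * x[i1].

    Assumes the half-pixel convention: src = (o + 0.5) * n_in / n_out - 0.5,
    clamped so that both taps stay inside the input.
    """
    taps = []
    for o in range(n_out):
        src = max((o + 0.5) * n_in / n_out - 0.5, 0.0)
        i0 = min(int(math.floor(src)), n_in - 1)
        i1 = min(i0 + 1, n_in - 1)  # at the right edge both taps hit the last input
        taps.append((i0, i1, src - i0))
    return taps


def upsample_linear_1d(x: List[float], n_out: int) -> List[float]:
    """Forward pass: each output blends its two neighboring inputs."""
    return [(1 - t) * x[i0] + t * x[i1] for i0, i1, t in linear_taps(len(x), n_out)]


def upsample_linear_1d_grad(g: List[float], n_in: int) -> List[float]:
    """Backward pass as a weighted scatter-add into the two source inputs."""
    dx = [0.0] * n_in
    for (i0, i1, t), go in zip(linear_taps(n_in, len(g)), g):
        dx[i0] += (1 - t) * go
        dx[i1] += t * go
    return dx
```

A quick sanity check is the adjoint identity `sum(g * up(x)) == sum(up_grad(g) * x)`, which holds for any linear map and its gradient.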

awni • Oct 27 '24

Hi @awni, thank you for figuring that out! It was indeed a mistake; I intended to use "linear" here. However, when I changed it to "linear" and ran the test again on my M1 Pro MacBook Pro, the running time actually increased:

average time of MLX: 275.73419642448425 ms

I will find an M1 Max machine to verify this again.

arnold-yan • Oct 28 '24

@arnold-yan you're right, the benchmark is slower with linear 😓. I made a mistake. Let me keep digging.

awni • Oct 28 '24

Hi @arnold-yan, https://github.com/ml-explore/mlx/pull/1541 should improve your benchmark a lot. I ran it on an M1 Max and an M3 Max, and the numbers are now:

Machine   MLX     PyTorch
M1 Max    4.615   11.93
M3 Max    1.938   10.77

awni • Oct 29 '24

Thanks @awni! Let me test this PR. The running time on M1 Pro is now:

average time of MLX: 7.930095672607422 ms.

Good job!

arnold-yan • Oct 30 '24

By the way, I also tried the mode="linear" change on top of your PR on M1 Pro. The times are now:

average time of MLX: 18.393556356430054
average time of Pytorch: 15.363117218017578

MLX is still a little slower than PyTorch here. Does that mean there is still room to speed up MLX for linear mode?

arnold-yan • Oct 30 '24