[Feature] Add Metal support for FFT

Open awni opened this issue 6 months ago • 9 comments

Add Metal backend support for the FFT primitive, as mentioned here: https://github.com/ml-explore/mlx-examples/issues/249

awni avatar Jan 08 '24 02:01 awni

I would like to take up this issue @awni. Please let me know if I can take this implementation up.

aneeshk1412 avatar Feb 01 '24 20:02 aneeshk1412

Do you have experience with GPU programming? This is not a trivial one so I would recommend starting with something simpler if not.

awni avatar Feb 01 '24 21:02 awni

I have some experience from my internship at Amazon HPC and a GPU programming course. I am not completely familiar with Metal, but I am looking at its documentation. I'm familiar with the basics of FFT. Would that be enough to start on this?

aneeshk1412 avatar Feb 02 '24 00:02 aneeshk1412

It's hard for me to answer. I would recommend you take a look at parallel implementations of FFT. There is also this code which does FFT on Metal. Maybe a good place to start is to benchmark that code against a CPU implementation and just see how usable it might be?

We can use this thread to discuss your findings.
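
For anyone following along, here is a minimal sketch of why FFT parallelizes well. It is plain Python using only the standard library, not MLX or Metal code, and the function names (`bit_reverse_permute`, `fft_iterative`, `dft_naive`) are made up for illustration. In the iterative radix-2 form, every butterfly within a stage touches its own pair of elements, so on a GPU each one could in principle be handled by a separate thread, with a barrier between stages.

```python
import cmath


def bit_reverse_permute(x):
    """Reorder x so that index i moves to the bit-reversed index of i."""
    n = len(x)
    bits = n.bit_length() - 1
    out = [0j] * n
    for i in range(n):
        rev = int(format(i, f"0{bits}b")[::-1], 2) if bits else 0
        out[rev] = complex(x[i])
    return out


def fft_iterative(x):
    """Iterative radix-2 Cooley-Tukey FFT (input length must be a power of two)."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    a = bit_reverse_permute(x)
    size = 2
    while size <= n:  # log2(n) stages, executed one after another
        half = size // 2
        w_step = cmath.exp(-2j * cmath.pi / size)
        # Within one stage, each (start, k) butterfly reads and writes its own
        # pair of elements, so the butterflies are independent of each other.
        for start in range(0, n, size):
            w = 1 + 0j
            for k in range(half):
                even = a[start + k]
                odd = a[start + k + half] * w
                a[start + k] = even + odd
                a[start + k + half] = even - odd
                w *= w_step
        size *= 2
    return a


def dft_naive(x):
    """O(n^2) reference DFT, used only to check the FFT result."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]
```

Checking `fft_iterative` against `dft_naive` on a few inputs is a cheap correctness test before moving on to timing comparisons.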

awni avatar Feb 02 '24 00:02 awni

What's the status on this? It's blocking some audio processing work I'm trying to do. It looks like there's an MPS implementation. Is there any interest in adding it?

A tangential question: why doesn't the codebase leverage MPS more? IIUC it handles broadcasting between matrices correctly and has optimized ops for lower dimensions.

Rifur13 avatar Feb 12 '24 17:02 Rifur13

Just to chime in: if this operator is implemented, then when I eventually finish implementing conv3d, I can accelerate the operation for larger kernel shapes. FFT conv tends to outperform Winograd conv for larger kernels. It would be great to have this so we can dispatch the appropriate conv kernel depending on shape.

I can also take a stab at this if nobody is working on it anymore @awni.
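
To make the FFT conv idea concrete, here is a rough 1D sketch in plain Python (standard library only). It is not MLX code, and the names `fft`, `conv_direct`, and `conv_fft` are hypothetical helpers for illustration. It relies on the convolution theorem: zero-pad both inputs to a common power-of-two length, multiply their spectra pointwise, and invert. Note that this computes a true convolution (kernel flipped), while ML frameworks usually mean cross-correlation by "conv".

```python
import cmath


def fft(x, inverse=False):
    """Recursive radix-2 FFT; the inverse is unnormalized (caller divides by n)."""
    n = len(x)
    if n == 1:
        return [complex(x[0])]
    even = fft(x[0::2], inverse)
    odd = fft(x[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out


def conv_direct(signal, kernel):
    """Full linear convolution, O(len(signal) * len(kernel))."""
    out = [0.0] * (len(signal) + len(kernel) - 1)
    for i, s in enumerate(signal):
        for j, k in enumerate(kernel):
            out[i + j] += s * k
    return out


def conv_fft(signal, kernel):
    """Same result via the convolution theorem, O(n log n) in the padded length n."""
    out_len = len(signal) + len(kernel) - 1
    n = 1
    while n < out_len:  # pad to a power of two so the radix-2 FFT applies
        n *= 2
    fs = fft(list(signal) + [0.0] * (n - len(signal)))
    fk = fft(list(kernel) + [0.0] * (n - len(kernel)))
    prod = [a * b for a, b in zip(fs, fk)]  # pointwise product of the spectra
    inv = fft(prod, inverse=True)
    return [v.real / n for v in inv[:out_len]]
```

The direct version's cost grows with the product of the signal and kernel lengths, while the FFT version depends only on the padded length, which is the reason the FFT route becomes attractive as kernels get larger.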

AndreSlavescu avatar Feb 16 '24 20:02 AndreSlavescu

@AndreSlavescu See also: https://github.com/ml-explore/mlx/issues/811#issuecomment-1988934827

adonath avatar Mar 11 '24 20:03 adonath

Could you reuse the code written for PyTorch? They added MPS support for FFT. https://github.com/pytorch/pytorch/pull/119670

avinashahuja avatar Mar 14 '24 04:03 avinashahuja

It's not a great fit for us as we don't wrap MPSGraph and likely wouldn't make an exception for FFTs.

awni avatar Mar 14 '24 04:03 awni