mlx
[Feature] Add Metal support for FFT
Add Metal backend support for the FFT primitive, as mentioned here: https://github.com/ml-explore/mlx-examples/issues/249
I would like to take up this issue @awni. Please let me know if I can take this implementation up.
Do you have experience with GPU programming? This is not a trivial one so I would recommend starting with something simpler if not.
I have some experience from my internship at Amazon HPC and a GPU programming course. I am not completely familiar with Metal, but I am looking at its documentation. I'm familiar with the basics of FFT. Would that be enough to start on this?
It's hard for me to answer. I would recommend you take a look at parallel implementations of FFT. There is also this code which does FFT on Metal. Maybe a good place to start is to benchmark that code against a CPU implementation and just see how usable it might be?
We can use this thread to discuss your findings.
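For anyone picking this up, here is a minimal pure-Python sketch of an iterative radix-2 Cooley-Tukey FFT, checked against a naive DFT. It is only meant to show the structure that parallel implementations usually exploit: within each stage, every butterfly is independent of the others, so each one could in principle be handled by its own GPU thread. The function names `fft_radix2` and `naive_dft` are made up for illustration; this is not MLX code and not the Metal code linked above.

```python
import cmath


def naive_dft(x):
    """O(n^2) reference DFT used only to check the fast version."""
    n = len(x)
    return [
        sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
        for k in range(n)
    ]


def fft_radix2(x):
    """Iterative radix-2 FFT; the input length must be a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    bits = n.bit_length() - 1

    # Bit-reversal permutation: element i moves to the index whose
    # binary digits are those of i in reverse order.
    out = [0j] * n
    for i in range(n):
        rev = int(format(i, f"0{bits}b")[::-1], 2) if bits else 0
        out[rev] = complex(x[i])

    # log2(n) stages. Inside one stage, each (start, k) butterfly reads and
    # writes its own pair of elements, so the butterflies do not depend on
    # each other and could run in parallel.
    half = 1
    while half < n:
        for start in range(0, n, 2 * half):
            for k in range(half):
                w = cmath.exp(-1j * cmath.pi * k / half)  # twiddle factor
                a = out[start + k]
                b = out[start + k + half] * w
                out[start + k] = a + b
                out[start + k + half] = a - b
        half *= 2
    return out


if __name__ == "__main__":
    data = [complex(i % 5, -i) for i in range(16)]
    fast = fft_radix2(data)
    slow = naive_dft(data)
    print(max(abs(p - q) for p, q in zip(fast, slow)))  # should be tiny
```

A real benchmark would time a GPU kernel against an optimized CPU FFT library rather than against Python loops; the sketch only gives a correctness reference to compare outputs against.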
What's the status on this? It's blocking some audio processing work I'm trying to do. It looks like there's an MPS implementation. Any interest in adding this?
A tangential question to this, why doesn't the codebase leverage MPS more? IIUC it handles broadcasting between matrices correctly and has optimized ops for lower dimensions.
Just to chime in: if this operator is implemented, then when I eventually finish implementing conv3d, I can accelerate the operation for larger kernel shapes. FFT conv tends to outperform Winograd conv for larger kernel shapes. It would be great to have this as a feature so we can dispatch the appropriate conv kernel depending on shape.
I can also take a stab at this if nobody is working on it anymore @awni.
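To illustrate the FFT conv idea, here is a small 1D sketch in pure Python based on the convolution theorem: zero-pad both inputs, transform, multiply pointwise, and transform back. Direct convolution costs O(N·K) for signal length N and kernel length K, while the FFT route costs O(N log N) regardless of K, which is why it becomes attractive for large kernels. The helper names are hypothetical, and this computes a true (flipped-kernel) convolution rather than the cross-correlation that ML "conv" layers usually compute; it says nothing about how MLX would dispatch kernels.

```python
import cmath


def fft(x, inverse=False):
    """Recursive radix-2 FFT (unnormalized); len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    even = fft(x[0::2], inverse)
    odd = fft(x[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out


def direct_conv(signal, kernel):
    """Full 1D convolution with the O(N*K) double loop."""
    out = [0.0] * (len(signal) + len(kernel) - 1)
    for i, s in enumerate(signal):
        for j, k in enumerate(kernel):
            out[i + j] += s * k
    return out


def fft_conv(signal, kernel):
    """Full 1D convolution via pointwise multiplication in frequency space."""
    size = len(signal) + len(kernel) - 1
    n = 1
    while n < size:  # pad to a power of two so the radix-2 FFT applies
        n *= 2
    fs = fft([complex(v) for v in signal] + [0j] * (n - len(signal)))
    fk = fft([complex(v) for v in kernel] + [0j] * (n - len(kernel)))
    prod = [a * b for a, b in zip(fs, fk)]
    res = fft(prod, inverse=True)
    # The inverse transform above is unnormalized, so divide by n here.
    return [v.real / n for v in res[:size]]


if __name__ == "__main__":
    sig = [1, 2, 3, 4, 5]
    ker = [1, 0, -1]
    print(direct_conv(sig, ker))
    print([round(v, 6) for v in fft_conv(sig, ker)])
```

Both functions should agree up to floating-point rounding; the crossover point where the FFT version wins depends on the hardware and the kernel size, so it would need to be measured.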
@AndreSlavescu See also: https://github.com/ml-explore/mlx/issues/811#issuecomment-1988934827
Could you reuse the code written for PyTorch? They added MPS support for FFT. https://github.com/pytorch/pytorch/pull/119670
It's not a great fit for us as we don't wrap MPSGraph and likely wouldn't make an exception for FFTs.