Feature complete Metal FFT
Proposed changes
A feature complete GPU FFT implementation in Metal.
Supports
- All n < 2^20
- Real and Inverse FFTs: fft, ifft, rfft, irfft
- ND FFTs: fft2, ifft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn
Algorithms
- A mixed radix out of place Stockham FFT for n where all prime factors p have 2 <= p <= 13.
- Rader's Algorithm for n with one prime factor p > 13 where p-1 can be computed via Stockham.
- Bluestein's Algorithm for all other n.
- Four Step FFT for n > 4096 when the FFT can no longer be done purely in GPU shared memory.
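As a rough illustration of how the size classes above partition n, here is a minimal Python sketch. The names `prime_factors`, `is_stockham_size`, and `choose_algorithm` are hypothetical and are not taken from the PR. The sketch assumes "one prime factor p > 13" means exactly one such factor counted with multiplicity, and it reports the four step condition as a separate flag because the list does not say how it combines with the other algorithms.

```python
MAX_RADIX = 13           # largest prime radix handled by Stockham, per the list above
SHARED_MEM_LIMIT = 4096  # sizes above this need the four step path, per the list above


def prime_factors(n):
    """Return the prime factors of n with multiplicity, in ascending order."""
    factors = []
    p = 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors


def is_stockham_size(n):
    """True if every prime factor of n is at most MAX_RADIX."""
    return all(p <= MAX_RADIX for p in prime_factors(n))


def choose_algorithm(n):
    """Return (base_algorithm, needs_four_step) for an FFT of size n.

    Assumption: this mirrors the prose rules only, not the PR's planner.
    """
    if n < 2:
        raise ValueError("n must be at least 2")
    large = [p for p in prime_factors(n) if p > MAX_RADIX]
    if not large:
        base = "stockham"
    elif len(large) == 1 and is_stockham_size(large[0] - 1):
        base = "rader"
    else:
        base = "bluestein"
    return base, n > SHARED_MEM_LIMIT
```

Under these assumptions, `choose_algorithm(1024)` falls in the Stockham class, `choose_algorithm(17)` in the Rader class (since 16 is a power of two), and `choose_algorithm(47)` in the Bluestein class (since 46 = 2 * 23 and 23 > 13).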
Performance
For 1D complex to complex FFTs with 2 <= n < 512 on my M1 Max, the average bandwidths are:
MLX GPU: 142.7 GB/s
MPS GPU: 69.3 GB/s
MLX CPU: 5.9 GB/s
So this implementation is about 2x faster than MPS on average and about 24x faster than CPU MLX which uses pocketfft.
This implementation does specialize for different values of n with Metal function constants so it will have more overhead than MPS on the first call for new Stockham/Rader sizes.
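The first-call cost described above is the usual trade-off of per-size specialization: the expensive build happens once per new size and later calls reuse the result. A minimal Python sketch of that pattern follows. The function `get_specialized_kernel` and the string it returns are hypothetical stand-ins, not the Metal API or the PR's code.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def get_specialized_kernel(n):
    """Hypothetical stand-in for building a kernel specialized on size n.

    In a real backend this is where the slow compile step would happen.
    The cache ensures it runs only once per distinct n.
    """
    return f"fft_kernel_n{n}"


# The first call for a size pays the build cost; repeats are cache hits.
get_specialized_kernel(64)
get_specialized_kernel(64)
get_specialized_kernel(128)
info = get_specialized_kernel.cache_info()
```

Here `info.misses` counts the sizes that triggered a build and `info.hits` counts the calls that reused one.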
Very impressive perf!
Regarding the design, there is a big style difference from other MLX ops which we should change if possible. Basically you do the dispatch at the op level rather than the Primitive level. I see how this might be easier as you have access to all the ops you need for different FFT algorithms, but I don't think we should do it this way. The compute graph should be more independent of the implementation details. Also, I don't think the FFT plans themselves should be part of the compute graph (they are an implementation detail).
This redesign may require some changes to our existing backend to make it workable for you to use the requisite back-end ops from the FFT primitive's eval_gpu.
That makes sense to me, it did feel a little anti-pattern bloating out the graph but the MLX api is just really convenient! Let me give the re-write a go today, I don't think it'll be too bad.
We have really bad support for doing stuff on arrays inside primitives (MLX wasn't really designed with that in mind 😓 ). But I think we can improve it a lot if needed.
OK that took a little while but I think the FFTs are in a reasonable state now:
- All the GPU planning/running logic has been moved to metal/fft.cpp so we're not bloating the graph at all
- Added a no transpose four step FFT implementation so big powers of two are fast now (~100-140GB/s on M1 Max)
- Added FFT to the JIT
- Refactored the reading/writing so we now support RFFT/IRFFT for Stockham/Rader/Bluestein/4 Step directly in the kernel
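For readers unfamiliar with the four step decomposition mentioned above, here is a minimal pure-Python sketch of the textbook version for n = n1 * n2. It is not the no transpose kernel from this PR: it uses a naive O(n^2) DFT for the sub-transforms and an explicit index remap at the end, and the names `dft` and `four_step_fft` are hypothetical.

```python
import cmath


def dft(x):
    """Naive O(n^2) forward DFT, used here for the small sub-transforms."""
    n = len(x)
    return [
        sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
        for k in range(n)
    ]


def four_step_fft(x, n1, n2):
    """Textbook four step FFT of a length n1 * n2 sequence.

    Input index:  j = j1 * n2 + j2
    Output index: k = k1 + n1 * k2
    """
    n = n1 * n2
    if len(x) != n:
        raise ValueError("len(x) must equal n1 * n2")

    # Step 1: n2 DFTs of length n1 over j1 (columns of the n1 x n2 matrix).
    cols = [dft([x[j1 * n2 + j2] for j1 in range(n1)]) for j2 in range(n2)]

    # Step 2: multiply by the twiddle factors exp(-2*pi*i * j2 * k1 / n).
    for j2 in range(n2):
        for k1 in range(n1):
            cols[j2][k1] *= cmath.exp(-2j * cmath.pi * j2 * k1 / n)

    # Step 3: n1 DFTs of length n2 over j2.
    rows = [dft([cols[j2][k1] for j2 in range(n2)]) for k1 in range(n1)]

    # Step 4: write results to their final positions (the transpose).
    out = [0j] * n
    for k1 in range(n1):
        for k2 in range(n2):
            out[k1 + n1 * k2] = rows[k1][k2]
    return out


# Compare against the direct DFT on a small input.
signal = [complex(v, -v) for v in range(12)]
max_error = max(
    abs(a - b) for a, b in zip(four_step_fft(signal, 3, 4), dft(signal))
)
```

The point of the decomposition is that each sub-transform has length n1 or n2 rather than n, which is what lets a large FFT be split into pieces small enough to fit in fast on-chip memory.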