
Feature complete Metal FFT

Open barronalex opened this issue 1 year ago • 3 comments

Proposed changes

A feature complete GPU FFT implementation in Metal.

Supports

  • All n < 2^20
  • Real and Inverse FFTs: fft, ifft, rfft, irfft
  • ND FFTs: fft2, ifft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn
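The real-input variants return only the non-redundant half of the spectrum. As a rough reference for those semantics (NumPy-style conventions assumed; this is not the Metal implementation), here is a naive Python sketch: `rfft` keeps the n // 2 + 1 bins of a real signal, and `irfft` rebuilds the full spectrum from Hermitian symmetry before inverting.

```python
import cmath

def dft(x, inverse=False):
    """Naive O(n^2) complex DFT; inverse=True applies the 1/n-normalized inverse."""
    n = len(x)
    sign = 1 if inverse else -1
    out = []
    for k in range(n):
        s = sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
        out.append(s / n if inverse else s)
    return out

def rfft(x):
    """Real input of length n -> the n // 2 + 1 non-redundant frequency bins."""
    return dft([complex(v) for v in x])[: len(x) // 2 + 1]

def irfft(X, n):
    """Rebuild the full spectrum via X[k] = conj(X[n - k]), invert, keep the real part."""
    full = list(X) + [X[n - k].conjugate() for k in range(len(X), n)]
    return [v.real for v in dft(full, inverse=True)]
```

For example, a length-5 real signal yields 3 complex bins, and `irfft(rfft(x), 5)` recovers `x` up to rounding.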

Algorithms

  • A mixed-radix, out-of-place Stockham FFT for n whose prime factors p all satisfy 2 <= p <= 13.
  • Rader's algorithm for n with a single prime factor p > 13, where the length p-1 FFT can be computed via Stockham.
  • Bluestein's algorithm for all other n.
  • A four-step FFT for n > 4096, when the FFT no longer fits entirely in GPU shared memory.
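The selection rules above can be summarized as a small planner. This is a hypothetical Python sketch: `choose_algorithm`, `MAX_RADIX` and `SHARED_MEM_LIMIT` are illustrative names, and it assumes "one prime factor p > 13" means exactly one such factor counted with multiplicity. The actual planner lives in the Metal backend and may differ.

```python
def prime_factors(n):
    """Return the prime factorization of n as a list (with repeats)."""
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors

MAX_RADIX = 13           # largest radix handled by the Stockham kernels
SHARED_MEM_LIMIT = 4096  # above this size, fall back to a four-step FFT

def is_stockham_size(n):
    """True if every prime factor of n is a supported radix (2..13)."""
    return all(p <= MAX_RADIX for p in prime_factors(n))

def choose_algorithm(n):
    """Hypothetical planner: returns (base algorithm, whether four-step is needed)."""
    large = [p for p in prime_factors(n) if p > MAX_RADIX]
    if not large:
        algo = "stockham"
    elif len(large) == 1 and is_stockham_size(large[0] - 1):
        algo = "rader"      # length p-1 sub-FFT is Stockham-friendly
    else:
        algo = "bluestein"  # e.g. p-1 has a large factor, or several large primes
    return algo, n > SHARED_MEM_LIMIT

for n in (1024, 68, 47, 8192):
    print(n, choose_algorithm(n))
# 1024 ('stockham', False)
# 68 ('rader', False)
# 47 ('bluestein', False)
# 8192 ('stockham', True)
```

Here 68 = 4 · 17 goes to Rader because 16 = 2^4, while 47 goes to Bluestein because 46 = 2 · 23 contains another large prime.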

Performance

For 1D complex-to-complex FFTs with 2 <= n < 512 on my M1 Max, the average bandwidths are:

MLX GPU: 142.7 GB/s
MPS GPU: 69.3 GB/s
MLX CPU: 5.9 GB/s

So this implementation is about 2x faster than MPS on average and about 24x faster than the MLX CPU backend, which uses pocketfft.
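As a quick check, the speedup figures follow directly from the averaged bandwidths listed above:

```latex
\frac{142.7\ \text{GB/s}}{69.3\ \text{GB/s}} \approx 2.06 \quad (\text{vs. MPS}),
\qquad
\frac{142.7\ \text{GB/s}}{5.9\ \text{GB/s}} \approx 24.2 \quad (\text{vs. MLX CPU})
```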

This implementation specializes kernels for each value of n using Metal function constants, so the first call for a new Stockham/Rader size incurs more overhead than MPS.

[Benchmark plots: Radix 2-13; Other Sizes]

barronalex, May 10 '24

Very impressive perf!

Regarding the design, there is a big style difference from other MLX ops, which we should change if possible. Basically, you do the dispatch at the op level rather than the Primitive level. I see how this might be easier, since you have access to all the ops you need for the different FFT algorithms, but I don't think we should do it this way. The compute graph should be more independent of the implementation details. Also, I don't think it makes sense for the FFT plans themselves to be part of the compute graph (they're an implementation detail).

This redesign may require some changes to our existing backend so that you can call the requisite backend ops from the FFT primitive's eval_gpu.
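To make the suggested split concrete, here is a toy Python sketch (not MLX's actual classes). The op records a single `FFT` node regardless of size, and the primitive's GPU evaluation picks the plan. `Node`, `fft_op` and `eval_gpu_fft` are hypothetical stand-ins for an MLX op, primitive and `eval_gpu`, and the Bluestein plan is simplified to a list of kernel names.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """One node in a toy compute graph: a primitive name, its inputs and parameters."""
    primitive: str
    inputs: list = field(default_factory=list)
    params: dict = field(default_factory=dict)

def fft_op(x: Node, n: int) -> Node:
    # Op level: always a single FFT node; no algorithm details enter the graph.
    return Node("FFT", [x], {"n": n})

def smooth(n: int, limit: int = 13) -> bool:
    """True if every prime factor of n is <= limit."""
    for p in range(2, limit + 1):
        while n % p == 0:
            n //= p
    return n == 1

def eval_gpu_fft(n: int) -> list:
    # Primitive level: the plan (kernels to launch) is decided at evaluation time.
    if smooth(n):
        return ["stockham"]
    # Illustrative Bluestein plan: chirp, FFT, pointwise multiply, inverse FFT, chirp.
    return ["chirp_mul", "stockham", "pointwise_mul", "stockham_inverse", "chirp_mul"]
```

Whatever `eval_gpu_fft` decides, the graph built by `fft_op` stays one node, which is the independence the review asks for.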

awni, May 13 '24

That makes sense to me. It did feel like a bit of an anti-pattern to bloat the graph, but the MLX API is just really convenient! Let me give the rewrite a go today; I don't think it'll be too bad.

barronalex, May 13 '24

We have really bad support for operating on arrays inside primitives (MLX wasn't really designed with that in mind 😓), but I think we can improve it a lot if needed.

awni, May 13 '24

OK that took a little while but I think the FFTs are in a reasonable state now:

  • All the GPU planning/running logic has been moved to metal/fft.cpp, so we're not bloating the graph at all
  • Added a no-transpose four-step FFT implementation, so large powers of two are now fast (~100-140 GB/s on M1 Max)
  • Added FFT to the JIT
  • Refactored the reading/writing, so RFFT/IRFFT are now supported directly in the kernel for Stockham, Rader, Bluestein and four-step
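The four-step decomposition splits a length n = n1 · n2 FFT into n1 FFTs of length n2, a twiddle-factor multiplication, and n2 FFTs of length n1. Below is a small Python reference sketch of that math only: naive DFTs stand in for the GPU sub-FFTs, and the final strided write plays the role of the transpose. How the PR's Metal kernels avoid a separate transpose pass isn't described here, so treat this as an illustration, not the implementation.

```python
import cmath

def dft(x):
    """Naive O(n^2) forward DFT, used for the small sub-FFTs."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

def four_step_fft(x, n1, n2):
    """FFT of length n = n1 * n2 with t = t1 + n1*t2 and k = n2*k1 + k2."""
    n = n1 * n2
    assert len(x) == n
    # Step 1: n1 FFTs of length n2 over the strided subsequences x[t1::n1].
    a = [dft(x[t1::n1]) for t1 in range(n1)]
    # Step 2: multiply by the twiddle factors W_n^(t1 * k2).
    for t1 in range(n1):
        for k2 in range(n2):
            a[t1][k2] *= cmath.exp(-2j * cmath.pi * t1 * k2 / n)
    # Steps 3-4: n2 FFTs of length n1 across t1, written to X[n2*k1 + k2].
    out = [0j] * n
    for k2 in range(n2):
        col = dft([a[t1][k2] for t1 in range(n1)])
        for k1 in range(n1):
            out[n2 * k1 + k2] = col[k1]
    return out
```

For instance, `four_step_fft(x, 3, 4)` on a length-12 input matches `dft(x)` to within rounding error.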

barronalex, Jun 05 '24