Fast Metal FFT for all N
Proposed changes
A feature-complete Metal FFT that's faster than both CPU and PyTorch MPS in the majority of 1D cases.
Fully functional, but still needs some cleanup.
Resolves #399.
Supports
- All values of N (tested up to 2^20)
- Real and inverse FFTs: `fft`, `ifft`, `rfft`, `irfft`
- N-D FFTs: `fft2`, `ifft2`, `rfft2`, `irfft2`, `fftn`, `ifftn`, `rfftn`, `irfftn`
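For reference, a quick sketch of how these are called from the Python API (shapes and sizes are illustrative):

```python
import mlx.core as mx

x = mx.random.normal((4, 1024))  # real float32 input

X = mx.fft.fft(x)    # complex 1D FFT along the last axis
y = mx.fft.ifft(X)   # inverse FFT

R = mx.fft.rfft(x)   # real FFT: last axis becomes n // 2 + 1 bins
z = mx.fft.irfft(R)  # inverse real FFT back to length n

X2 = mx.fft.fft2(x)  # 2D FFT over the last two axes
Xn = mx.fft.fftn(x)  # N-D FFT over all axes
```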
Performance
For N < 1024, 1D FFTs on my M1 Max:
- Faster than PyTorch MPS for ~90% of FFT sizes
- ~1.5X higher average throughput than PyTorch MPS
- ~13X higher average throughput than CPU
We're only behind MPS on some multiples of 7 and on all multiples of 11 and 13.
Our Bluestein's implementation is significantly more efficient than MPS's for N < 1024.
Note: To save time, I ran these benchmarks at a slightly lower batch size than is needed to max out the bandwidth for the powers of 2. I'll run a full sweep shortly, but in my experiments so far the relative speeds hold.
Implementation Details
For N <= 2048 whose prime factors are all <= 7:
- A mixed-radix, out-of-place Stockham FFT in threadgroup memory (a decomposition sketch follows this list)
- Codelets for radix-2, 3, 4, 5, and 7
- Kernels are specialized at runtime for each N with Metal function constants
- Threadgroup batching for small N to improve performance
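As a rough NumPy sketch of the mixed-radix decomposition (written recursively for clarity rather than in the iterative in-threadgroup Stockham form the kernel actually uses), with naive DFT codelets at the leaves:

```python
import numpy as np

def dft_codelet(x):
    # Naive O(r^2) DFT matrix product, standing in for the radix codelets.
    n = np.arange(len(x))
    return np.exp(-2j * np.pi * np.outer(n, n) / len(x)) @ x

def mixed_radix_fft(x, radices):
    # Decimation-in-time Cooley-Tukey over N = prod(radices).
    N = len(x)
    if len(radices) == 1:
        return dft_codelet(x)
    r, m = radices[0], N // radices[0]
    # r interleaved sub-FFTs of length m
    sub = np.stack([mixed_radix_fft(x[i::r], radices[1:]) for i in range(r)])
    # Twiddles between the two stages: exp(-2*pi*i * i*k / N)
    tw = np.exp(-2j * np.pi * np.outer(np.arange(r), np.arange(m)) / N)
    # Combine sub-transforms with an r-point DFT; output index is q*m + k
    Wr = np.exp(-2j * np.pi * np.outer(np.arange(r), np.arange(r)) / r)
    return (Wr @ (sub * tw)).reshape(-1)

x = np.random.randn(2 * 3 * 5 * 7)  # N = 210
assert np.allclose(mixed_radix_fft(x, [2, 3, 5, 7]), np.fft.fft(x))
```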
For all other N <= 1024:
- A fused Bluestein's algorithm implementation
- Bluestein twiddles are computed on the CPU in float64 to maintain acceptable precision for the overall algorithm (see the sketch below)
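For context, Bluestein's algorithm turns an arbitrary-length DFT into a chirp convolution that can be evaluated with power-of-two FFTs. A minimal NumPy sketch (the fused kernel obviously differs, but the float64 chirp mirrors the precision note above):

```python
import numpy as np

def bluestein_fft(x):
    N = len(x)
    # Chirp w^(n^2 / 2) computed in float64, mirroring the CPU-side
    # float64 twiddle computation described above.
    n = np.arange(N, dtype=np.float64)
    chirp = np.exp(-1j * np.pi * n * n / N)
    # Power-of-two FFT size for the linear convolution (>= 2N - 1)
    M = 1 << (2 * N - 1).bit_length()
    a = np.zeros(M, dtype=np.complex128)
    a[:N] = x * chirp
    b = np.zeros(M, dtype=np.complex128)
    b[:N] = np.conj(chirp)
    b[M - N + 1:] = np.conj(chirp[1:][::-1])  # wrap the negative indices
    # Convolution theorem: circular convolution via a pointwise FFT product
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return chirp * conv[:N]

x = np.random.randn(11) + 1j * np.random.randn(11)  # N = 11, a "bad" size
assert np.allclose(bluestein_fft(x), np.fft.fft(x))
```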
For N > 1024:
- The four-step FFT algorithm (indexing sketched below)
- If N has prime factors > 1024, we use a manual version of Bluestein's algorithm implemented with MLX ops
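The four-step decomposition treats a length N = n1 * n2 FFT as FFTs over the columns and rows of an n1 x n2 matrix, with a twiddle multiply and a transpose in between. A NumPy sketch of the indexing:

```python
import numpy as np

def four_step_fft(x, n1, n2):
    A = x.reshape(n1, n2)
    # Step 1: n2 FFTs of length n1 (down the columns)
    B = np.fft.fft(A, axis=0)
    # Step 2: twiddle multiply between the two stages
    k1, t = np.arange(n1)[:, None], np.arange(n2)[None, :]
    C = B * np.exp(-2j * np.pi * k1 * t / (n1 * n2))
    # Step 3: n1 FFTs of length n2 (across the rows)
    D = np.fft.fft(C, axis=1)
    # Step 4: transpose to put the output in natural order
    return D.T.reshape(-1)

x = np.random.randn(4096) + 1j * np.random.randn(4096)
assert np.allclose(four_step_fft(x, 64, 64), np.fft.fft(x))
```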
RFFT:
- We implement a custom kernel for real FFTs that uses a trick to perform two transforms at a time, doubling the effective bandwidth (see the sketch below)
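I'm assuming this is the standard packing trick: two real sequences go into the real and imaginary parts of a single complex FFT, and the two spectra are separated afterwards. A NumPy sketch of the identity:

```python
import numpy as np

def two_rffts(x1, x2):
    # Pack two real signals into one complex FFT of the same length.
    N = len(x1)
    Z = np.fft.fft(x1 + 1j * x2)
    # conj(Z[(N - k) % N]) flips the spectrum for the unpacking identity
    Zr = np.conj(Z[-np.arange(N) % N])
    X1 = 0.5 * (Z + Zr)    # spectrum of x1
    X2 = -0.5j * (Z - Zr)  # spectrum of x2
    return X1[: N // 2 + 1], X2[: N // 2 + 1]  # rfft keeps n//2 + 1 bins

x1, x2 = np.random.randn(1024), np.random.randn(1024)
X1, X2 = two_rffts(x1, x2)
assert np.allclose(X1, np.fft.rfft(x1)) and np.allclose(X2, np.fft.rfft(x2))
```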
Areas for Improvement
Codelet optimizations and additions
The radix codelets are currently very naive and could be replaced with hand-tuned or compiler-generated ones that perform fewer than the naive O(r^2) operations of a radix-r DFT. We should also add radix-11 and radix-13 codelets to match MPS and VkFFT.
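To illustrate what a hand-tuned codelet buys over the naive matrix product (a NumPy sketch; the real codelets are Metal functions):

```python
import numpy as np

def radix3_naive(x):
    # Roughly the current approach: a full O(r^2) DFT matrix product.
    n = np.arange(3)
    return np.exp(-2j * np.pi * np.outer(n, n) / 3) @ x

def radix3_tuned(x):
    # Hand-tuned radix-3 butterfly: shares the (x1 + x2) and (x1 - x2)
    # terms instead of recomputing every matrix entry.
    x0, x1, x2 = x
    t1 = x1 + x2
    t2 = x0 - 0.5 * t1
    t3 = -1j * (np.sqrt(3.0) / 2.0) * (x1 - x2)
    return np.array([x0 + t1, t2 + t3, t2 - t3])

x = np.random.randn(3) + 1j * np.random.randn(3)
assert np.allclose(radix3_naive(x), radix3_tuned(x))
assert np.allclose(radix3_tuned(x), np.fft.fft(x))
```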
Performance on N-D and four-step FFT cases
These currently have quite a few unnecessary copies. A fused implementation incorporating the transpose and twiddle factors would bring us closer to the maximum bandwidth.
Accuracy
Accuracy is comparable to MPS' implementation but about an order of magnitude behind pocketfft. More careful twiddle factor computation inspired by pocketfft could help here. Precision also suffers a bit on very large N. Computing the twiddle factors in float64 as we do with Bluestein's would help.
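A sketch of the float64-twiddle idea (a hypothetical helper, just to show the pattern of doing the trig in double precision and casting once at the end):

```python
import numpy as np

def twiddles(n):
    # Evaluate exp(-2*pi*i*k / n) in float64, then cast the table to
    # complex64. For large n, float32 argument-reduction error in the trig
    # dominates the FFT's overall error, so double here is cheap accuracy.
    k = np.arange(n, dtype=np.float64)
    return np.exp(-2j * np.pi * k / n).astype(np.complex64)
```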
IRFFT
`irfft` on the GPU currently only works on outputs of `rfft` (there are a couple of exceptions in the tests to account for this).
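In other words, the supported GPU path is the round trip (illustrative):

```python
import mlx.core as mx

x = mx.random.normal((1024,))
y = mx.fft.irfft(mx.fft.rfft(x))  # supported: input comes from rfft
# An arbitrary complex array that is not the output of rfft may not go
# through the GPU path correctly yet, per the limitation above.
```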
Convolution theorem
The fused Bluestein's implementation contains a convolution implemented with FFTs via the convolution theorem. For larger kernel sizes we might want to adapt this and add it to the main convolution implementation as suggested in #811.
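The idea referenced here, as a NumPy sketch (linear convolution via pointwise multiplication of spectra):

```python
import numpy as np

def fft_conv(x, h):
    # Convolution theorem: zero-pad to the full linear-convolution length,
    # multiply the spectra, and invert.
    n = len(x) + len(h) - 1
    return np.fft.irfft(np.fft.rfft(x, n) * np.fft.rfft(h, n), n)

x, h = np.random.randn(1000), np.random.randn(64)
assert np.allclose(fft_conv(x, h), np.convolve(x, h))
```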
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
Do you mind rebasing this, @barronalex?