
Fast Metal FFT for all N

Open · barronalex opened this pull request 1 year ago · 1 comment

Proposed changes

A feature-complete Metal FFT that's faster than both CPU and PyTorch MPS in the majority of 1D cases.

Fully functional, but still needs some cleanup.

Resolves #399.

Supports

  • All values of N (tested up to 2^20)
  • Real and Inverse FFTs: fft, ifft, rfft, irfft
  • ND FFTs: fft2, ifft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn

Performance

For N < 1024, 1D FFTs on my M1 Max:

  • Faster than PyTorch MPS for ~90% of FFT sizes
  • ~1.5X higher average throughput than PyTorch MPS
  • ~13X higher average throughput than CPU

[Plot: throughput for all FFT sizes ("All")]

We're only behind MPS on some multiples of 7 and on all multiples of 11 and 13:

[Plot: radix 2-13 FFT sizes ("Radix 2-13")]

Our Bluestein's implementation is significantly more efficient for N < 1024:

[Plot: Bluestein's FFT sizes ("Bluestein's")]

Note: For the sake of time, I ran at a slightly smaller batch size than is required to max out the bandwidth for the powers of 2. I'll run a full benchmark shortly, but in my experiments so far the relative speeds seem to hold.

Implementation Details

For N <= 2048 whose prime factors are all <= 7:

  • Stockham's algorithm in threadgroup memory with a mixed-radix, out-of-place FFT
  • Codelets for radix-2,3,4,5,7
  • Kernels are specialized at runtime for each N with Metal function constants
  • Threadgroup batching for small N to improve performance
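As a rough sketch of the core idea (plain numpy, radix-2 only — the actual kernels are mixed-radix with codelets up to radix-7, run in threadgroup memory, and are specialized per N), a Stockham autosort FFT ping-pongs between two buffers and needs no bit-reversal pass:

```python
import numpy as np

def stockham_fft(x):
    """Radix-2 Stockham autosort FFT (out-of-place, no bit reversal)."""
    N = len(x)  # must be a power of two for this radix-2 sketch
    a = np.asarray(x, dtype=complex).copy()
    b = np.empty_like(a)
    n, s = N, 1  # n: current sub-transform length, s: stride
    while n > 1:
        m = n // 2
        w = np.exp(-2j * np.pi * np.arange(m) / n)  # stage twiddles
        for p in range(m):
            for q in range(s):
                u = a[q + s * p]
                v = a[q + s * (p + m)]
                b[q + s * (2 * p)] = u + v               # butterfly; outputs land
                b[q + s * (2 * p + 1)] = (u - v) * w[p]  # already sorted
        a, b = b, a  # ping-pong buffers instead of writing in place
        n, s = m, 2 * s
    return a
```

The out-of-place ping-pong is what makes Stockham attractive on GPUs: every stage reads and writes contiguous runs, with no separate reordering pass.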

For all other N <= 1024:

  • A fused Bluestein's algorithm implementation
  • Bluestein twiddles are computed on CPU in float64 to maintain acceptable precision for the overall algorithm.
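For reference, Bluestein's algorithm turns an arbitrary-length DFT into a convolution with a chirp sequence, which can itself be computed with power-of-two FFTs. A minimal numpy sketch (the fused kernel obviously differs in structure):

```python
import numpy as np

def bluestein_fft(x):
    """DFT of arbitrary length n via Bluestein's chirp-z transform."""
    n = len(x)
    k = np.arange(n)
    # Chirp e^{-i*pi*k^2/n}; this PR computes these on CPU in float64.
    chirp = np.exp(-1j * np.pi * k * k / n)
    # Convolution length must be >= 2n - 1; round up to a power of two.
    m = 1 << (2 * n - 1).bit_length()
    b = np.zeros(m, dtype=complex)
    b[:n] = np.conj(chirp)
    b[m - n + 1:] = np.conj(chirp[1:])[::-1]  # wrap the negative indices
    # Circular convolution via the convolution theorem.
    conv = np.fft.ifft(np.fft.fft(x * chirp, m) * np.fft.fft(b))
    return conv[:n] * chirp
```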

For N > 1024:

  • The four step FFT algorithm
  • If N has prime factors > 1024, we use a manual version of Bluestein's implemented with MLX ops
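The four-step decomposition splits N = n1 * n2 into two passes of smaller FFTs, with a twiddle multiplication and a transpose in between. A numpy sketch of the math (not the Metal implementation):

```python
import numpy as np

def four_step_fft(x, n1, n2):
    """FFT of length n1 * n2 via the four-step (Bailey) decomposition."""
    a = np.asarray(x, dtype=complex).reshape(n1, n2)
    a = np.fft.fft(a, axis=0)  # step 1: n2 FFTs of length n1 down the columns
    k1 = np.arange(n1)[:, None]
    j2 = np.arange(n2)[None, :]
    a = a * np.exp(-2j * np.pi * k1 * j2 / (n1 * n2))  # step 2: twiddles
    a = np.fft.fft(a, axis=1)  # step 3: n1 FFTs of length n2 along the rows
    return a.T.ravel()         # step 4: transpose into natural output order
```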

RFFT:

  • We implement a custom kernel for real FFTs that uses a trick to perform two transforms at a time, doubling the effective bandwidth
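The standard trick (which I assume is the one referenced) packs two real signals into one complex FFT and separates the two spectra using Hermitian symmetry:

```python
import numpy as np

def rfft_two_at_once(x, y):
    """Compute rfft(x) and rfft(y) with a single complex FFT."""
    n = len(x)
    z = np.fft.fft(np.asarray(x) + 1j * np.asarray(y))
    zr = np.conj(np.roll(z[::-1], 1))  # conj(Z[(n - k) % n])
    fx = 0.5 * (z + zr)    # Hermitian (even) part -> spectrum of x
    fy = -0.5j * (z - zr)  # anti-Hermitian (odd) part -> spectrum of y
    h = n // 2 + 1
    return fx[:h], fy[:h]  # one-sided spectra, matching rfft output
```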

Areas for Improvement

Codelet optimizations and additions

The radix codelets are currently extremely naive and could be replaced with hand-tuned or compiled ones that perform fewer than O(N^2) operations. We should also add radix-11 and radix-13 codelets to match MPS and VkFFT.

Performance on ND and four step FFT cases

These currently contain quite a few unnecessary copies. A fused implementation incorporating the transpose and twiddle factors would bring us closer to the maximum bandwidth.

Accuracy

Accuracy is comparable to MPS' implementation but about an order of magnitude behind pocketfft. More careful twiddle factor computation inspired by pocketfft could help here. Precision also suffers a bit on very large N. Computing the twiddle factors in float64 as we do with Bluestein's would help.
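To illustrate why double-precision twiddles matter (a numpy demonstration, not the kernel code): computing each twiddle in float64 and rounding once keeps the error near float32 machine epsilon, while accumulating w^k by repeated single-precision multiplication lets the error grow with k:

```python
import numpy as np

n = 1 << 14
k = np.arange(n)
ref = np.exp(-2j * np.pi * k / n)  # float64 reference twiddles

# Compute in float64, then round to float32 once: error stays ~1 ulp.
tw_f64 = ref.astype(np.complex64)

# Accumulate w^k by repeated float32 multiplication: error grows with k.
w = np.complex64(np.exp(-2j * np.pi / n))
tw_f32 = np.empty(n, dtype=np.complex64)
acc = np.complex64(1.0)
for i in range(n):
    tw_f32[i] = acc
    acc = np.complex64(acc * w)

err_f64 = np.abs(tw_f64 - ref).max()
err_f32 = np.abs(tw_f32 - ref).max()
```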

IRFFT

irfft on GPU currently only works on outputs of rfft (there are a couple of exceptions in the tests to account for this).

Convolution theorem

The fused Bluestein's implementation contains a convolution implemented with FFTs via the convolution theorem. For larger kernel sizes we might want to adapt this and add it to the main convolution implementation as suggested in #811.
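For context, convolution via the theorem zero-pads both inputs to at least len(x) + len(h) - 1 so the circular convolution matches the linear one; a numpy sketch of the approach:

```python
import numpy as np

def fft_conv(x, h):
    """Linear convolution of x and h via the convolution theorem."""
    m = len(x) + len(h) - 1          # linear convolution length
    n = 1 << (m - 1).bit_length()    # round up to a power of two
    y = np.fft.ifft(np.fft.fft(x, n) * np.fft.fft(h, n))[:m]
    return y.real if np.isrealobj(x) and np.isrealobj(h) else y
```

FFT convolution wins over the direct O(len(x) * len(h)) method once the kernel is large enough, which is the regime #811 is concerned with.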

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

barronalex avatar Apr 11 '24 06:04 barronalex

Do you mind rebasing this @barronalex ?

awni avatar Apr 12 '24 04:04 awni