[BUG] FFT fails on certain array lengths

Open · cdcapano opened this issue · 3 comments

Describe the bug

Doing an FFT on arrays of length 2**21 and 2**22 results in a kernel failure, but larger array sizes work.

To Reproduce

A simple script to reproduce:

```python
import mlx.core as mx

size = int(2**21)
x = mx.ones(size)
mx.eval(mx.fft.fft(x, stream=mx.gpu))
```

This will result in the following error:

```
Terminating due to uncaught exception: [metal::Device] Unable to load function four_step_mem_8192_float2_float2_0_false
Function four_step_mem_8192_float2_float2_0_false was not found in the library

Abort trap: 6
```

A similar thing happens for an array of length 2**22. However, the code succeeds for lengths 2**23, 2**24, 2**25, and so on up to 2**28 (I don't have enough memory to test beyond that). By "succeed" I mean the function runs without failure; I haven't checked that the output is actually correct.

Expected behavior

The FFT should work for 2**21 and 2**22 if larger array sizes work. At the very least, the error should be caught and reported gracefully instead of aborting the process.

Desktop:

  • OS version: macOS 15.1.1
  • MLX version: 0.22.0

Additional context

Digging into the code a bit, I can see why it's failing. For a size of 2**21, `plan.n1` gets set to 2048. Later on, that causes `threadgroup_mem_size` to be set to 8192. However, I don't know why that doesn't cause the assert at line 641 to raise an error.

I see the comment at line 640, `// FFTs up to 2^20 are currently supported`, so I'm not sure why the 2^23 FFTs run at all. Even if the assert worked properly, why the limit of 2^20? In the research application we're trying to use this for, we will be evaluating arrays of length 2^21 to 2^25, so it would be ideal if these sizes could be handled.

cdcapano · Jan 27 '25 18:01

Indeed looks like a bug.

awni · Jan 27 '25 23:01

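Why the assert doesn't fire: `assert` is only active in debug builds, not release builds (release builds define `NDEBUG`, which compiles `assert` out). Building in debug mode makes the assert work: `CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace --debug`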

Why the 2^21 error happens: in this case the algorithm needs the FFT kernel `four_step_mem_8192_float2_float2_0_false`, but the library cannot find it because it is never instantiated. Adding `instantiate_ffts(8192)` would instantiate it.

However, this kernel would not fit into 32 KB of threadgroup memory, which is the limit on the Metal GPUs of most Macs.

Why 2^23 works: the FFT algorithm recursively breaks a large FFT into smaller FFTs. For 2^23 it breaks into (128 x 64) x 1024; for 2^21, into 2048 x 1024. The algorithm does not break 2048 any further because it is not larger than `MAX_STOCKHAM_FFT_SIZE`, which is set to 4096. That is also why 2^22 = 4096 x 1024 fails.

hriverg · Feb 16 '25 07:02
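The decomposition argument above can be checked with a small model. This is a hypothetical sketch, not MLX's planner: it assumes the inner factor is always 1024 (inferred from the three decompositions quoted above), splits the outer factor into ceil/floor halves of its log2 only when it exceeds `MAX_STOCKHAM_FFT_SIZE = 4096`, and flags any constituent FFT larger than 1024, the strided four-step limit barronalex gives below. The memory helper further assumes each threadgroup slot holds one `float2` (8 bytes). `decompose`, `predicted_to_fail`, `threadgroup_bytes`, and `INNER_SIZE` are illustrative names.

```python
import math

# Values quoted in this thread. INNER_SIZE is an assumption inferred from the
# decompositions above, not a constant taken from the MLX source.
MAX_STOCKHAM_FFT_SIZE = 4096
STRIDED_LEAF_LIMIT = 1024          # larger constituents exhaust threadgroup memory
INNER_SIZE = 1024                  # hypothetical fixed inner factor
THREADGROUP_LIMIT_BYTES = 32 * 1024


def split_sqrt(m):
    """Split m = 2**k into (2**ceil(k/2), 2**floor(k/2))."""
    k = m.bit_length() - 1
    return 2 ** math.ceil(k / 2), 2 ** (k // 2)


def decompose(n):
    """Hypothetical plan for n = 2**p with p > 20: (outer factors, inner size)."""
    assert n > 2**20 and n & (n - 1) == 0, "model only covers powers of two above 2**20"
    outer = n // INNER_SIZE
    if outer > MAX_STOCKHAM_FFT_SIZE:
        # Too large for a single Stockham pass: split it once more.
        return list(split_sqrt(outer)), INNER_SIZE
    # Small enough for Stockham, so it is left whole, even if it exceeds 1024.
    return [outer], INNER_SIZE


def predicted_to_fail(n):
    """True if any constituent FFT exceeds the strided four-step limit."""
    outer, inner = decompose(n)
    return any(f > STRIDED_LEAF_LIMIT for f in outer + [inner])


def threadgroup_bytes(slots, bytes_per_slot=8):
    """Bytes of threadgroup memory, assuming one float2 (2 x float32) per slot."""
    return slots * bytes_per_slot


if __name__ == "__main__":
    for p in range(21, 29):
        outer, inner = decompose(2**p)
        status = "fails" if predicted_to_fail(2**p) else "ok"
        print(f"2**{p}: {' x '.join(map(str, outer))} x {inner} -> {status}")
    print(threadgroup_bytes(8192), "bytes needed vs", THREADGROUP_LIMIT_BYTES, "available")
```

Under these assumptions the model flags 2**21 and 2**22 and passes 2**23 through 2**28, matching the behavior reported in the issue, and an 8192-slot kernel would need 64 KiB, twice the 32 KB limit.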

I will fix this properly soon, but in the meantime here are two options:

  1. Run on CPU (for example, pass `stream=mx.cpu` to `mx.fft.fft`).
  2. Run a four-step FFT with MLX ops (this is roughly what I'm going to implement in the C++ backend):
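```python
import math
import mlx.core as mx

def four_step_fft(x, axis: int = 0):
    # Four-step FFT: view the length-n axis as an (n1, n2) grid, FFT along n1,
    # multiply by twiddle factors, then FFT along n2.
    axis = axis % x.ndim
    n = x.shape[axis]
    assert n & (n - 1) == 0, "Only supports powers of two"
    log_n = math.log2(n) / 2
    n1, n2 = 2**(math.ceil(log_n)), 2**(math.floor(log_n))
    orig_shape = tuple(x.shape)
    shape = orig_shape[:axis] + (n1, n2) + orig_shape[axis + 1:]
    x = x.reshape(shape)
    # Twiddle factors exp(-2*pi*i*k1*b/n), shaped to broadcast against x
    # even when there are dimensions after `axis`.
    ij = mx.arange(n1)[:, mx.newaxis] * mx.arange(n2)[mx.newaxis]
    twiddles = mx.exp(mx.array(-2j * mx.pi * ij / n))
    twiddles = twiddles.reshape((n1, n2) + (1,) * (len(orig_shape) - axis - 1))
    step_one = mx.fft.fft(x, axis=axis) * twiddles
    # Swap so the n2 axis comes first, FFT it; flattening then yields the
    # natural output order k = k1 + n1 * k2.
    step_two = mx.fft.fft(mx.swapaxes(step_one, axis, axis + 1), axis=axis)
    return step_two.reshape(orig_shape)
```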

The problem at the moment is that the strided four-step FFT implementation runs out of threadgroup memory when the constituent FFTs are larger than 1024 (hence the 1024 * 1024 = 2**20 limit). I'll implement a nested four-step FFT, as above, to fix this.

barronalex · Feb 16 '25 15:02
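The nested four-step idea can be illustrated with a minimal pure-Python sketch, assuming power-of-two lengths and using a naive DFT as a stand-in for a single-kernel leaf FFT. It is not the MLX C++ implementation; `nested_four_step_fft`, `naive_dft`, and `leaf_limit` are hypothetical names. It uses the same index mapping as `four_step_fft` above (input index m = a * n2 + b, output index k = k1 + n1 * k2) but recurses on any sub-FFT still longer than `leaf_limit`.

```python
import cmath
import math


def naive_dft(x):
    """O(n^2) reference DFT: X[k] = sum_m x[m] * exp(-2*pi*i*m*k/n)."""
    n = len(x)
    return [sum(x[m] * cmath.exp(-2j * math.pi * m * k / n) for m in range(n))
            for k in range(n)]


def nested_four_step_fft(x, leaf_limit=1024):
    """Recursive four-step FFT for power-of-two lengths.

    FFTs longer than leaf_limit are split into n1 x n2 sub-FFTs, and any
    sub-FFT that is still too long is split again, so no leaf exceeds
    leaf_limit. Leaves use naive_dft here (a stand-in for a GPU kernel).
    """
    n = len(x)
    assert n & (n - 1) == 0, "Only supports powers of two"
    assert leaf_limit >= 2, "leaf_limit must be at least 2 for the split to shrink"
    if n <= leaf_limit:
        return naive_dft(x)
    k = n.bit_length() - 1
    n1, n2 = 2 ** math.ceil(k / 2), 2 ** (k // 2)
    # Step 1: length-n1 FFT of each column b, i.e. of x[a * n2 + b] over a.
    cols = [nested_four_step_fft(x[b::n2], leaf_limit) for b in range(n2)]
    out = [0j] * n
    for k1 in range(n1):
        # Step 2: multiply by twiddle factors exp(-2*pi*i*k1*b/n).
        row = [cols[b][k1] * cmath.exp(-2j * math.pi * k1 * b / n) for b in range(n2)]
        # Step 3: length-n2 FFT across the row (recursing if still too long).
        z = nested_four_step_fft(row, leaf_limit)
        # Step 4: transpose into natural order k = k1 + n1 * k2.
        for k2 in range(n2):
            out[k1 + n1 * k2] = z[k2]
    return out
```

With `leaf_limit=1024`, a 2**21 input splits into 2048 x 1024, and each length-2048 column FFT splits again into 64 x 32, so no sub-FFT exceeds 1024. Pure Python is far too slow at that size; the point is only the recursion structure.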