
[Feature] Support FFT-based convolution

Open adonath opened this issue 1 year ago • 11 comments

It would be nice to have FFT-based convolution supported in MLX. FFT-based convolution shows much better performance for large images/arrays and kernels. The FFT building blocks are already supported in MLX, so it is mostly a matter of combining them into a convolution operation.
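
For context, this relies on the convolution theorem: convolution in real space is pointwise multiplication in frequency space. A minimal 1-D sketch with the existing MLX FFT ops (padding to the full linear-convolution length so the DFT's circular convolution matches the linear one):

import mlx.core as mx

def convolve_1d_fft(a, b):
    # Pad both signals to the full linear-convolution length, so the
    # circular convolution implied by the DFT equals the linear result.
    n = a.size + b.size - 1
    return mx.fft.irfft(mx.fft.rfft(a, n=n) * mx.fft.rfft(b, n=n), n=n)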

adonath avatar Mar 08 '24 21:03 adonath

@awni I wouldn't mind looking into implementing this, what do you think?

sebblanchet avatar Mar 09 '24 00:03 sebblanchet

One challenge here is that the FFT is not yet supported on the GPU (in Metal). So you could use it on the CPU, but there it would almost certainly be much slower than our GPU convolution.

Also I think FFT-based convolution is more of an implementation detail. If there are some sizes that are slow for you, please share any benchmarks. We can then figure out the best way to make them faster (which may or may not require an FFT-based convolution).

awni avatar Mar 09 '24 05:03 awni

Thanks @awni and @sebblanchet! I did a quick implementation of an FFT-based convolution in MLX:

import mlx.core as mx


def _centered(arr, newshape):
    """Crop `arr` to `newshape`, centered on the original array."""
    newshape = mx.array(newshape)
    currshape = mx.array(arr.shape)

    startind = (currshape - newshape) // 2
    endind = startind + newshape
    myslice = [slice(startind[k].item(), endind[k].item()) for k in range(len(endind))]
    return arr[tuple(myslice)]


def convolve_fft(image, kernel, stream):
    """FFT-based convolution for MLX arrays."""
    image_2d, kernel_2d = image[0, 0], kernel[0, 0]

    # Pad to the full linear-convolution size to avoid circular wrap-around
    shape = [image_2d.shape[i] + kernel_2d.shape[i] - 1 for i in range(image_2d.ndim)]

    # Multiply in frequency space and transform back (convolution theorem)
    image_ft = mx.fft.rfft2(image, s=shape, stream=stream)
    kernel_ft = mx.fft.rfft2(kernel, s=shape, stream=stream)
    result = mx.fft.irfft2(image_ft * kernel_ft, s=shape, stream=stream)

    # Crop back to the "same" output size
    return _centered(result, image.shape)
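
For example, a quick usage sketch (the FFTs run on the CPU since there is no Metal FFT yet; the shapes are NCHW-style with the spatial dims last, which is what the rfft2 default axes assume):

image = mx.random.normal(shape=(1, 1, 1024, 1024))
kernel = mx.random.normal(shape=(1, 1, 33, 33))

result = convolve_fft(image, kernel, stream=mx.cpu)
mx.eval(result)  # mlx is lazy, force the computation
print(result.shape)  # (1, 1, 1024, 1024), same spatial size as the input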

I also did a simple benchmark: a random 1024x1024 image with varying kernel sizes, comparing mx.conv2d on the GPU and on the CPU, the FFT-based algorithm from above, and, as a reference, SciPy's FFT convolution implementation (a sketch of the timing loop follows the list of observations below). The result is the following:

[Figure: mlx-conv-mini-benchmark, runtime vs. kernel size for the four implementations]

I think it matches the expectations exactly:

  • the GPU is faster than the CPU for the native convolution
  • the native convolution is faster for small kernel sizes
  • for large kernel sizes the FFT clearly wins, because of the O(N log N) scaling instead of O(N^2)
  • SciPy's FFT convolve is faster than mine (not too surprising)
  • the transition point where the CPU FFT becomes faster than the native GPU convolution is at a kernel size of roughly 20-30
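
For reference, the timing loop is essentially the following (a simplified sketch of what the gist linked further down does, not the exact code; note that mx.conv2d expects channels-last NHWC input, and mx.eval is needed to force MLX's lazy evaluation before stopping the clock):

import time
import mlx.core as mx

def time_conv2d(size=1024, kernel_size=33, device=mx.gpu, n_repeat=10):
    # mx.conv2d takes NHWC input and (C_out, kh, kw, C_in) weights
    image = mx.random.normal(shape=(1, size, size, 1))
    weight = mx.random.normal(shape=(1, kernel_size, kernel_size, 1))
    mx.eval(image, weight)

    start = time.perf_counter()
    for _ in range(n_repeat):
        mx.eval(mx.conv2d(image, weight, padding=kernel_size // 2, stream=device))
    return (time.perf_counter() - start) / n_repeat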

In general I think it is still worth having an FFT-based convolution. For NNs with small kernels there is no point, but there are many scientific applications that rely on large kernels (think of cross-correlations, convolution with pathological point spread functions, etc.).

I think it is worth re-opening.

adonath avatar Mar 11 '24 16:03 adonath

Ok sounds good! Thanks for the benchmarks, that's really interesting!

awni avatar Mar 12 '24 03:03 awni

One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement.

awni avatar Mar 12 '24 03:03 awni

Thanks for re-opening @awni!

> One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement.

This is what SciPy has too, see https://github.com/scipy/scipy/blob/v1.12.0/scipy/signal/_signaltools.py#L1161. There is the option to either measure or actually compute the FLOPs. Measuring only makes sense for repeated convolutions, but probably gives the most accurate results on arbitrary architectures. Looking at the SciPy code, computing the FLOPs seems maybe too complex. Or is there a general way to predict FLOPs for MLX operations? (would be nice to have...)

In general the performance of MLX operations is probably much more predictable across the more homogeneous M-series architectures. So there could be a third option: just parameterize the scaling laws based on empirical benchmarks or something similar...
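
Something along these lines could serve as the dispatch point (a purely hypothetical sketch; the threshold of 25 is eyeballed from the benchmark above, and layout handling between the two paths is glossed over):

import mlx.core as mx

def conv2d_auto(image, kernel, stream=mx.cpu, threshold=25):
    # Hypothetical heuristic: direct convolution for small kernels,
    # FFT path once the kernel is large enough to amortize the transforms.
    # (Layout glossed over: mx.conv2d expects NHWC, while convolve_fft
    # above assumes the spatial dims are the last two.)
    if max(kernel.shape[-2:]) < threshold:
        return mx.conv2d(image, kernel, stream=stream)
    return convolve_fft(image, kernel, stream=stream)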

adonath avatar Mar 12 '24 12:03 adonath

Here is the gist with the code for the benchmark: https://gist.github.com/adonath/3f16b30498c60f25cf1349792c15283c

adonath avatar Mar 12 '24 12:03 adonath

Can I work on it?

rajveer43 avatar Aug 24 '24 11:08 rajveer43

Go for it!

awni avatar Aug 24 '24 13:08 awni

> Go for it!

It's my first contribution to the repo, could you please give me some guidance on which files I need to modify?

rajveer43 avatar Aug 24 '24 14:08 rajveer43