
[BUG] Provide CUDA implementation of FFT

Open ekamperi opened this issue 3 months ago • 8 comments

Describe the bug

Attempting an FFT calculation with the CUDA backend throws a runtime error. Not sure if it's a bug or a feature request.

To Reproduce

snapshots-gpu-4000adax1-20gb-tor1# cat reproduce_fft_error.py
#!/usr/bin/env python3
"""
Minimal reproduction of MLX FFT CUDA backend error.

This script demonstrates the error: "FFT has no CUDA implementation."
"""

import mlx.core as mx

def main():
    print("Testing MLX FFT with CUDA backend...")
    print(f"MLX version: {mx.__version__ if hasattr(mx, '__version__') else 'unknown'}")
    print(f"Default device: {mx.default_device()}")

    # Create a simple 1D signal
    x = mx.array([1.0, 2.0, 3.0, 4.0], dtype=mx.float32)
    print(f"Input array: {x}")

    try:
        # Attempt FFT operation that triggers the CUDA error
        result = mx.fft.fft(x)
        print(f"FFT result: {result}")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(f"Error type: {type(e).__name__}")
        raise

if __name__ == "__main__":
    main()
snapshots-gpu-4000adax1-20gb-tor1# uv run python reproduce_fft_error.py
Testing MLX FFT with CUDA backend...
MLX version: 0.29.0
Default device: Device(gpu, 0)
Input array: array([1, 2, 3, 4], dtype=float32)
Error occurred: FFT has no CUDA implementation.
Error type: RuntimeError
Traceback (most recent call last):
  File "/root/logfiles/reproduce_fft_error.py", line 30, in <module>
    main()
  File "/root/logfiles/reproduce_fft_error.py", line 23, in main
    print(f"FFT result: {result}")
                        ^^^^^^^^
RuntimeError: FFT has no CUDA implementation.
snapshots-gpu-4000adax1-20gb-tor1#

Expected behavior

FFT should work with the CUDA backend, or a fallback computation on the CPU should occur on the fly.
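A possible stopgap, assuming MLX's per-op stream/device targeting (the stream keyword that MLX ops accept) applies to the FFT ops as well: route just the FFT through the CPU while the rest of the program stays on the default GPU device.

import mlx.core as mx

x = mx.array([1.0, 2.0, 3.0, 4.0], dtype=mx.float32)

# Run only the FFT on the CPU stream; other ops keep using the
# default (CUDA) device. This trades a host round trip for a
# working result until a CUDA kernel exists.
result = mx.fft.fft(x, stream=mx.cpu)
print(result)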

Desktop

snapshots-gpu-4000adax1-20gb-tor1# nvidia-smi
Mon Sep  1 14:06:40 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 4000 Ada Gene...    On  |   00000000:01:00.0 Off |                  Off |
| 30%   33C    P8             13W /  130W |       2MiB /  20475MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
snapshots-gpu-4000adax1-20gb-tor1#
snapshots-gpu-4000adax1-20gb-tor1# lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 22.04.4 LTS
Release:        22.04
Codename:       jammy
snapshots-gpu-4000adax1-20gb-tor1#
snapshots-gpu-4000adax1-20gb-tor1# uv pip show mlx
Name: mlx
Version: 0.29.0
Location: /root/logfiles/.venv/lib/python3.12/site-packages
Requires:
Required-by:
snapshots-gpu-4000adax1-20gb-tor1#

ekamperi · Sep 01 '25 14:09

I think wrapping cuFFT is a good solution for this: https://developer.nvidia.com/cufft

awni · Sep 03 '25 19:09

Currently working on a branch implementing the cuFFT solution. I'm using our existing cuBLAS wrapper in mlx/backend/cuda/matmul.cpp as a reference for how we handle wrapping CUDA libraries.

Maalvi14 · Sep 05 '25 01:09

Awesome, thanks!

awni · Sep 05 '25 01:09

Hi, CUDA Python tech lead here 👋

cuFFT has an official Python solution, exposed as part of nvmath-python, the home for NVIDIA CPU & GPU math libraries. It should have everything that MLX needs. In fact, we will be replacing CuPy's FFT support with nvmath-python (https://github.com/cupy/cupy/issues/9237) because (1) it's official, (2) it's more performant, and (3) the installation footprint is small. This offers a much easier route for MLX to enable FFT support on NVIDIA GPUs, without worrying about Python bindings to cuFFT or how to call cuFFT correctly (which is actually nontrivial). Let us know how we can help! 🙂

cc: @samaid @aterrel @kkraus14 for vis
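For a sense of the API surface, a minimal sketch based on nvmath-python's documented entry points (nvmath.fft.fft for the function form and nvmath.fft.FFT for the stateful form; exact signatures are worth double-checking against the docs):

import cupy as cp
import nvmath

# cuFFT's C2C transform expects a complex input.
a = cp.arange(8, dtype=cp.complex64)

# Function-form API: one-shot transform, planning handled internally.
# CuPy in, CuPy out; the result stays on the GPU.
b = nvmath.fft.fft(a)

# Stateful API: plan once, execute many times. This is the shape a
# framework backend would likely want, since cuFFT plan creation is
# expensive relative to execution.
with nvmath.fft.FFT(a) as f:
    f.plan()
    b = f.execute()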

leofang · Sep 06 '25 01:09

Hi Leo, thanks for the idea!

I'll be checking out the nvmath-python cuFFT solution, and I'm looking to use mlx.core.from_dlpack for zero-copy interop while ensuring stream alignment.
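A rough sketch of that interop path, under a few assumptions to verify: that nvmath-python's function-form fft accepts a CuPy array, that mx.from_dlpack can ingest CuPy's __dlpack__ output zero-copy on the CUDA backend, and that a conservative stream sync is enough for correctness.

import cupy as cp
import nvmath
import mlx.core as mx

x = cp.array([1.0, 2.0, 3.0, 4.0], dtype=cp.complex64)
y = nvmath.fft.fft(x)  # computed by cuFFT, stays on the GPU

# Conservative: synchronize CuPy's current stream before handing the
# buffer over. Proper stream alignment (sharing one stream end to end)
# is the open question mentioned above.
cp.cuda.get_current_stream().synchronize()

# Import into MLX via DLPack; whether this is truly zero-copy on the
# CUDA backend needs to be verified.
y_mlx = mx.from_dlpack(y)
print(y_mlx)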

Maalvi14 · Sep 10 '25 15:09

@leofang I'm starting work on this enhancement. I'd like to know if you have any recommended platform where I can test my changes on NVIDIA GPUs, as I currently only have a MacBook. :)

Maalvi14 · Sep 10 '25 18:09

Hi @Maalvi14, glad to hear! I just checked: Google Colab still offers free T4 GPUs (the GPU runtime is not the default, which is CPU-only, so you need a few clicks to switch). Here's a Colab demo showing:

  • how to check the GPU status via nvidia-smi
  • how to install nvmath-python with the correct env (CUDA 12, which is what Colab uses today)
  • how to run a simple nvmath-python code with a CuPy array as input

https://colab.research.google.com/drive/13bfQ2-WhPydh94xDJIdSbGwbCNwPq-2x?usp=sharing
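For reference, the gist of those steps in one cell. The [cu12] extra and the cupy-cuda12x wheel name follow nvmath-python's and CuPy's documented CUDA 12 packaging; treat them as assumptions to verify against the notebook.

# In a Colab cell, after switching the runtime type to a GPU:
#   !nvidia-smi                                      # confirm the T4 is visible
#   !pip install "nvmath-python[cu12]" cupy-cuda12x  # CUDA 12 wheels
import cupy as cp
import nvmath

x = cp.array([1.0, 2.0, 3.0, 4.0], dtype=cp.complex64)
print(nvmath.fft.fft(x))  # runs through cuFFT on the GPU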

leofang · Sep 11 '25 01:09

brev.nvidia.com has a wide variety of GPU instances, including T4, but I am not sure if we offer free credits. If Colab does not work for you, I can ask around internally.

leofang · Sep 11 '25 01:09