[BUG] Provide CUDA implementation of FFT
**Describe the bug**
Attempting an FFT calculation with the CUDA backend throws a runtime error. Not sure if this is a bug or a feature request.
**To Reproduce**

`reproduce_fft_error.py`:

```python
#!/usr/bin/env python3
"""
Minimal reproduction of MLX FFT CUDA backend error.
This script demonstrates the error: "FFT has no CUDA implementation."
"""
import mlx.core as mx


def main():
    print("Testing MLX FFT with CUDA backend...")
    print(f"MLX version: {mx.__version__ if hasattr(mx, '__version__') else 'unknown'}")
    print(f"Default device: {mx.default_device()}")

    # Create a simple 1D signal
    x = mx.array([1.0, 2.0, 3.0, 4.0], dtype=mx.float32)
    print(f"Input array: {x}")

    try:
        # Attempt FFT operation that triggers the CUDA error
        result = mx.fft.fft(x)
        print(f"FFT result: {result}")
    except Exception as e:
        print(f"Error occurred: {e}")
        print(f"Error type: {type(e).__name__}")
        raise


if __name__ == "__main__":
    main()
```
```
snapshots-gpu-4000adax1-20gb-tor1# uv run python reproduce_fft_error.py
Testing MLX FFT with CUDA backend...
MLX version: 0.29.0
Default device: Device(gpu, 0)
Input array: array([1, 2, 3, 4], dtype=float32)
Error occurred: FFT has no CUDA implementation.
Error type: RuntimeError
Traceback (most recent call last):
  File "/root/logfiles/reproduce_fft_error.py", line 30, in <module>
    main()
  File "/root/logfiles/reproduce_fft_error.py", line 23, in main
    print(f"FFT result: {result}")
                         ^^^^^^^^
RuntimeError: FFT has no CUDA implementation.
```
**Expected behavior**
FFT should work with the CUDA backend, or fall back to a CPU computation on the fly.
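In the meantime, a possible workaround is to route just the FFT to the CPU backend explicitly. This is a minimal sketch, assuming `mx.fft.fft` accepts MLX's standard per-op `stream=` argument (a Device or Stream) as other ops do:

```python
import mlx.core as mx

x = mx.array([1.0, 2.0, 3.0, 4.0], dtype=mx.float32)

# Route only the FFT to the CPU backend; other ops can stay on the GPU.
y = mx.fft.fft(x, stream=mx.cpu)
mx.eval(y)
print(y)
```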
**Desktop**

```
snapshots-gpu-4000adax1-20gb-tor1# nvidia-smi
Mon Sep 1 14:06:40 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08 Driver Version: 575.57.08 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX 4000 Ada Gene... On | 00000000:01:00.0 Off | Off |
| 30% 33C P8 13W / 130W | 2MiB / 20475MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
snapshots-gpu-4000adax1-20gb-tor1#
snapshots-gpu-4000adax1-20gb-tor1# lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 22.04.4 LTS
Release: 22.04
Codename: jammy
snapshots-gpu-4000adax1-20gb-tor1#
snapshots-gpu-4000adax1-20gb-tor1# uv pip show mlx
Name: mlx
Version: 0.29.0
Location: /root/logfiles/.venv/lib/python3.12/site-packages
Requires:
Required-by:
```
I think wrapping cuFFT is a good solution for this: https://developer.nvidia.com/cufft
I'm currently working on a branch implementing the cuFFT solution. I'm using our existing cuBLAS wrapper in mlx/backend/cuda/matmul.cpp as a reference for how we handle wrapping CUDA libraries.
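As a rough sanity check for whatever lands (a sketch, not code from the branch), the CUDA path would need to reproduce what NumPy and the existing MLX CPU backend already agree on:

```python
import numpy as np
import mlx.core as mx

x_np = np.random.default_rng(0).standard_normal(256).astype(np.float32)
x = mx.array(x_np)

# Reference results today: NumPy and the existing MLX CPU backend.
ref = np.fft.fft(x_np)
cpu = np.array(mx.fft.fft(x, stream=mx.cpu))
assert np.allclose(cpu, ref, atol=1e-4)

# Once the cuFFT-backed path lands, the same call on the GPU stream should
# match as well (today this raises "FFT has no CUDA implementation."):
# gpu = np.array(mx.fft.fft(x, stream=mx.gpu))
# assert np.allclose(gpu, ref, atol=1e-4)
```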
Awesome, thanks!
Hi, CUDA Python tech lead here 👋
cuFFT has an official Python solution, exposed as part of nvmath-python, the home for NVIDIA's CPU & GPU math libraries. It should have everything that MLX needs. In fact, we will be replacing CuPy's FFT support with nvmath-python (https://github.com/cupy/cupy/issues/9237) because 1. it's official, 2. it's more performant, and 3. the installation footprint is small. This will offer a much easier route for MLX to enable FFT support on NVIDIA GPUs, without worrying about Python bindings to cuFFT or how to call cuFFT correctly (which is actually nontrivial). Let us know how we can help! 🙂
cc: @samaid @aterrel @kkraus14 for vis
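For reference, nvmath-python's function-form FFT looks roughly like this with a CuPy array as input (a sketch based on the public `nvmath.fft` API; the exact install extras may vary by environment):

```python
# pip install cupy-cuda12x "nvmath-python[cu12]"   (extras shown for a CUDA 12 setup)
import cupy as cp
import nvmath

# Device array allocated with CuPy; the FFT runs on the GPU via cuFFT.
x = cp.array([1.0, 2.0, 3.0, 4.0], dtype=cp.complex64)

# One-shot function-form API; the result is a CuPy array on the same device.
y = nvmath.fft.fft(x)
print(y)
```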
Hi Leo, thanks for the idea!
I'll be checking out the cuFFT nvmath-python solution, and I'm looking at using mlx.core.from_dlpack for zero-copy interop while making sure the streams are aligned.
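Roughly the flow being considered (a sketch only; it assumes mlx.core.from_dlpack can import CUDA device memory via the DLPack protocol, and it synchronizes the producing stream conservatively rather than sharing a stream):

```python
import cupy as cp
import nvmath
import mlx.core as mx

# FFT computed on the GPU by nvmath-python (cuFFT under the hood).
x = cp.arange(8).astype(cp.complex64)
y = nvmath.fft.fft(x)  # CuPy array living in device memory

# Make sure the producing stream has finished before MLX touches the buffer.
cp.cuda.get_current_stream().synchronize()

# Zero-copy handoff into MLX via DLPack; no device-to-host copy.
y_mlx = mx.from_dlpack(y)
print(y_mlx)
```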
@leofang I'm starting work on this enhancement. Do you have a recommended platform where I can test my changes on NVIDIA GPUs? I currently only have a MacBook. :)
Hi @Maalvi14, glad to hear it! I just checked: Google Colab still offers free T4 GPUs (it's not the default runtime, which is CPU-only, so you need a few clicks to switch). Here's a Colab demo showing:
- how to check the GPU status via `nvidia-smi`
- how to install `nvmath-python` with the correct env (CUDA 12, which is what Colab uses today)
- how to run a simple `nvmath-python` example with a CuPy array as input
https://colab.research.google.com/drive/13bfQ2-WhPydh94xDJIdSbGwbCNwPq-2x?usp=sharing
brev.nvidia.com has a wide variety of GPU instances, including T4, but I am not sure if we offer free credits. If Colab does not work for you, I can ask around internally.