Different outputs every time fft2 is called
Description
Each time fft2 is called, I get a different result.
import numpy as np
from jax.numpy.fft import fft2
x = np.random.randn(7603, 7603)
f1 = fft2(x).block_until_ready()
f2 = fft2(x).block_until_ready()
f3 = fft2(x).block_until_ready()
f4 = fft2(x).block_until_ready()
print(f'first call: {f1[0,0]}, second call: {f2[0,0]}, third call: {f3[0,0]}, fourth call: {f4[0,0]}')
# first call: (0.20946313440799713+0j), second call: (0.3694859743118286+0j), third call: (4124.7333984375+0.002197265625j), fourth call: (3916.027099609375-0.002685546875j)
This behavior does not occur when each dimension of x is smaller than 7600. Perhaps some out-of-bounds indexing is occurring with the large input?
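For reference, a minimal sketch of a consistency check (the 4096 and 7603 sizes are arbitrary illustrative choices) that simply reports whether two consecutive calls agree:
import numpy as np
from jax.numpy.fft import fft2

# Compare two consecutive fft2 calls on the same input; sizes are illustrative.
for n in (4096, 7603):
    x = np.random.randn(n, n)
    first = np.asarray(fft2(x).block_until_ready())
    second = np.asarray(fft2(x).block_until_ready())
    print(n, np.allclose(first, second))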
What jax/jaxlib version are you using?
jax v0.3.17, jaxlib v0.3.15
Which accelerator(s) are you using?
GPU (titan 1080ti)
Additional System Info
Linux
What version of CUDA do you have installed?
I'm unable to reproduce this on:
a) a V100 cloud machine
b) a T4 cloud machine with CUDA 11.7
c) a desktop 1080 GPU with CUDA 11.4
So there must be something different about your environment.
I'm also curious whether the value of x is changing between calls.
Outputs of nvcc
$ nvcc --version
bash: nvcc: command not found
I have CUDA 11.1 installed in /usr/local/cuda-11.1/, though.
Regarding whether x is changing: it appears to be unchanged between calls.
import numpy as np
from jax.numpy.fft import fft2
x = np.random.randn(7603, 7603)
print(x.sum()) # 1315.9940795048856
f1 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
f2 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
f3 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
f4 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
I can reproduce this with CUDA 11.1 but not with CUDA 11.2 or newer. This at least suggests a workaround you can use: update to a newer CUDA release. I'm not sure what's happening but the CUDA release dependence makes me suspect this is NVIDIA's bug.
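If it helps to confirm which cuFFT build a process actually loads (for example when several CUDA toolkits are installed side by side), one rough sketch is to call cufftGetVersion through ctypes, assuming the loader can resolve libcufft.so (a full path or versioned soname may be needed):
import ctypes

# Query cuFFT's own version number via the cufftGetVersion C API.
libcufft = ctypes.CDLL("libcufft.so")
version = ctypes.c_int()
libcufft.cufftGetVersion(ctypes.byref(version))
# The mapping from this number to a CUDA toolkit release is listed in
# NVIDIA's cuFFT release notes.
print("cuFFT version:", version.value)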
@sudhakarsingh27 Could you please determine whether this is a known cuFFT bug and, if so, whether there are workarounds we might use on the JAX side to get correct behavior on CUDA 11.1? The only alternative I can think of is to drop CUDA 11.1 support.
Looking into it
An even simpler (and more obvious) repro:
from jax.numpy.fft import fft2
import numpy as np
N = 7603
x = np.ones((N, N))
# y is the expected FFT of an all-ones N x N matrix: N*N at [0, 0], zeros elsewhere
y = np.zeros((N, N))
y[0,0] = N*N
f1 = fft2(x).block_until_ready()
f2 = fft2(x).block_until_ready()
e1 = np.linalg.norm(f1 - y) / np.linalg.norm(y)
e2 = np.linalg.norm(f2 - y) / np.linalg.norm(y)
print("Errors: ", e1, e2)
With CUDA 11.1
$ LD_PRELOAD=/path/to/cuda/11.1/libcufft.so python3 bug_12298.py
Errors: 8.710503652383524e-07 0.9999999827006407
With CUDA 11.2 and above (slightly different errors depending on the version, but all ~1e-7)
$ LD_PRELOAD=/path/to/cuda/11.6/libcufft.so python3 bug_12298.py
Errors: 6.324330558994378e-07 6.324330558994378e-07
@hawkinsp it does look like this is a bug that was resolved in CUDA 11.2: https://docs.nvidia.com/cuda/archive/11.2.0/cuda-toolkit-release-notes/index.html#cufft-resolved-issues
I assume you mean "Plans with primes larger than 127 in FFT size decomposition or FFT size being a prime number bigger than 4093 do not perform calculations on second and subsequent cufftExecute* calls. Regression was introduced in cuFFT 11.1".
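For anyone wondering which FFT sizes that description covers, here is a small sketch that applies the stated condition (a prime factor larger than 127 in the size decomposition, or the size itself being a prime larger than 4093); this is just my reading of the release note, not an official check:
def largest_prime_factor(n):
    # Trial division; returns the largest prime factor of n (n itself if n is prime).
    factor, largest = 2, 1
    while factor * factor <= n:
        while n % factor == 0:
            largest, n = factor, n // factor
        factor += 1
    return n if n > 1 else largest

def possibly_affected(size):
    # Condition quoted from the release note: a prime factor > 127 in the
    # decomposition, or the size itself being a prime > 4093.
    lpf = largest_prime_factor(size)
    return lpf > 127 or (lpf == size and size > 4093)

for size in (4096, 7600, 7603):
    print(size, possibly_affected(size))
Note that 7603 is itself a prime larger than 4093, which matches the sizes where the issue reproduces.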
Sounds like we should warn if CUDA 11.1 is used for FFTs then (or just drop support totally).
Yes, that's right.
Where should we document this? (In the JAX README perhaps, somewhere in the jax:cuda installation steps?)
Edit: Ah, I think you meant adding a warning in the code as well. Is there a similar check for any other library that we could add this alongside?
@cheshire We are thinking of making XLA raise an error if the CUDA version is below 11.2. What do you think of that? We could add a flag to disable it. At least people will know they are in an unsupported configuration and won't expect a flawless experience.
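For illustration only, a rough sketch of what such a check could look like from Python, using cudaRuntimeGetVersion from libcudart via ctypes; the real check would live in XLA/jaxlib and could sit behind a flag:
import ctypes

MIN_CUDA = 11020  # CUDA encodes versions as 1000 * major + 10 * minor, so 11.2 -> 11020

# Illustrative only: ask the CUDA runtime for its version and refuse to run below 11.2.
libcudart = ctypes.CDLL("libcudart.so")
version = ctypes.c_int()
libcudart.cudaRuntimeGetVersion(ctypes.byref(version))
if version.value < MIN_CUDA:
    raise RuntimeError(
        f"CUDA runtime {version.value} is older than 11.2; "
        "cuFFT results may be silently incorrect (see the resolved issue above)."
    )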
@hawkinsp What was the latest update on CUDA version policy from JAX?
These days, we raise an error for CUDA versions below 11.8. So I think we can declare this fixed.