
Different outputs every time fft2 is called

Open KyleLuther opened this issue 3 years ago • 11 comments

Description

Each time fft2 is called, I get a different result.

import numpy as np
from jax.numpy.fft import fft2

x = np.random.randn(7603, 7603)

f1 = fft2(x).block_until_ready()
f2 = fft2(x).block_until_ready()
f3 = fft2(x).block_until_ready()
f4 = fft2(x).block_until_ready()

print(f'first call: {f1[0,0]}, second call: {f2[0,0]}, third call: {f3[0,0]}, fourth call: {f4[0,0]}')
# first call: (0.20946313440799713+0j), second call: (0.3694859743118286+0j), third call: (4124.7333984375+0.002197265625j), fourth call: (3916.027099609375-0.002685546875j)

This behavior does not occur when the size of x is less than 7600 in each of the dimensions. Perhaps some out of bounds indexing is occurring with the large input?
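One way to pin down which calls are wrong is to cross-check each call against NumPy's CPU FFT (a minimal sketch, not part of the original report; it assumes the same `x` as above and enough host memory to hold the float64 reference):

import numpy as np
from jax.numpy.fft import fft2

x = np.random.randn(7603, 7603)
reference = np.fft.fft2(x)              # CPU reference result in complex128

for i in range(4):
    f = fft2(x).block_until_ready()     # GPU result via cuFFT
    rel_err = np.linalg.norm(np.asarray(f) - reference) / np.linalg.norm(reference)
    print(f"call {i + 1}: relative error vs NumPy = {rel_err:.2e}")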

What jax/jaxlib version are you using?

jax v0.3.17, jaxlib v0.3.15

Which accelerator(s) are you using?

GPU (titan 1080ti)

Additional System Info

Linux

KyleLuther avatar Sep 08 '22 21:09 KyleLuther

What version of CUDA do you have installed?

hawkinsp avatar Sep 09 '22 13:09 hawkinsp

I'm unable to reproduce this on: (a) a V100 cloud machine, (b) a T4 cloud machine with CUDA 11.7, or (c) a desktop 1080 GPU with CUDA 11.4.

So there must be something different about your environment.

I'm also curious if the value of x is changing between calls.

hawkinsp avatar Sep 09 '22 13:09 hawkinsp

Output of nvcc:

$ nvcc --version
 bash: nvcc: command not found

I have CUDA 11.1 installed in /usr/local/cuda-11.1/, though.

KyleLuther avatar Sep 09 '22 17:09 KyleLuther
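Since nvcc isn't on the PATH, one way to see which cuFFT/CUDA runtime the process actually loaded is to inspect its memory maps after running one FFT (a Linux-only sketch, not from the thread):

import jax.numpy as jnp

# Force jaxlib to load the cuFFT/CUDA runtime libraries.
jnp.fft.fft2(jnp.ones((8, 8))).block_until_ready()

# List the cuFFT and CUDA runtime shared objects mapped into this process.
with open("/proc/self/maps") as maps:
    libs = {line.split()[-1] for line in maps
            if "cufft" in line or "cudart" in line}
for path in sorted(libs):
    print(path)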

Regarding whether x is changing: it seems x is unchanged between calls.


import numpy as np
from jax.numpy.fft import fft2

x = np.random.randn(7603, 7603)
print(x.sum()) # 1315.9940795048856
f1 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
f2 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
f3 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856
f4 = fft2(x).block_until_ready()
print(x.sum()) # 1315.9940795048856

KyleLuther avatar Sep 09 '22 17:09 KyleLuther

I can reproduce this with CUDA 11.1 but not with CUDA 11.2 or newer. This at least suggests a workaround you can use: update to a newer CUDA release. I'm not sure what's happening but the CUDA release dependence makes me suspect this is NVIDIA's bug.

hawkinsp avatar Sep 13 '22 13:09 hawkinsp

@sudhakarsingh27 Could you please determine whether this is a known cuFFT bug and, if so, whether there are workarounds we might use on the JAX side to get correct behavior on CUDA 11.1? The only alternative I can think of is to drop CUDA 11.1 support.

hawkinsp avatar Sep 13 '22 13:09 hawkinsp

Looking into it

sudhakarsingh27 avatar Sep 13 '22 19:09 sudhakarsingh27

An even simpler (and more obvious) repro:

from jax.numpy.fft import fft2
import numpy as np

N = 7603
x = np.ones((N, N))

# Analytic FFT of an all-ones matrix: N*N at the DC bin, zero elsewhere.
y = np.zeros((N, N))
y[0, 0] = N * N

f1 = fft2(x).block_until_ready()
f2 = fft2(x).block_until_ready()

# Relative error of each call against the analytic result.
e1 = np.linalg.norm(f1 - y) / np.linalg.norm(y)
e2 = np.linalg.norm(f2 - y) / np.linalg.norm(y)

print("Errors: ", e1, e2)

With CUDA 11.1

$ LD_PRELOAD=/path/to/cuda/11.1/libcufft.so python3 bug_12298.py
Errors:  8.710503652383524e-07 0.9999999827006407

With CUDA 11.2 and above (slightly different errors depending on the version, but all ~1e-7)

$ LD_PRELOAD=/path/to/cuda/11.6/libcufft.so python3 bug_12298.py
Errors:  6.324330558994378e-07 6.324330558994378e-07

nvlcambier avatar Sep 13 '22 20:09 nvlcambier

@hawkinsp It does look like this is a bug that was resolved in CUDA 11.2: https://docs.nvidia.com/cuda/archive/11.2.0/cuda-toolkit-release-notes/index.html#cufft-resolved-issues

sudhakarsingh27 avatar Sep 13 '22 22:09 sudhakarsingh27

I assume you mean "Plans with primes larger than 127 in FFT size decomposition or FFT size being a prime number bigger than 4093 do not perform calculations on second and subsequent cufftExecute* calls. Regression was introduced in cuFFT 11.1".

Sounds like we should warn if CUDA 11.1 is used for FFTs then (or just drop support totally).

hawkinsp avatar Sep 13 '22 22:09 hawkinsp
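For reference, one literal reading of that release note translates into a size check like the following (a hypothetical sketch, not JAX code; fft_size_affected is an illustrative helper name):

def prime_factors(n):
    # Trial division; fine for FFT-sized integers.
    factors, d = [], 2
    while d * d <= n:
        while n % d == 0:
            factors.append(d)
            n //= d
        d += 1
    if n > 1:
        factors.append(n)
    return factors

def fft_size_affected(n):
    # Per the cuFFT 11.2 release note: composite sizes whose decomposition
    # contains a prime larger than 127, or sizes that are themselves primes
    # larger than 4093.
    factors = prime_factors(n)
    if len(factors) == 1:       # n is prime
        return n > 4093
    return max(factors) > 127

print(fft_size_affected(7603))  # True: 7603 is a prime larger than 4093
print(fft_size_affected(7600))  # False: 7600 = 2**4 * 5**2 * 19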

Yes, that's right.

Where should we document this? (in JAX README perhaps - somewhere in jax:cuda installation steps?)

Edit: Ah, I think you meant adding a warning in the code as well. Is there a similar check for any other lib so we can add this alongside that?

sudhakarsingh27 avatar Sep 13 '22 22:09 sudhakarsingh27

@cheshire We are thinking of making XLA raise an error if the CUDA version is below 11.2. What do you think? We could add a flag to disable this. At least people will know they are in an unsupported configuration and won't expect a flawless experience.

nouiz avatar Oct 12 '22 19:10 nouiz

@hawkinsp What was the latest update on CUDA version policy from JAX?

cheshire avatar Oct 12 '22 20:10 cheshire

These days, we raise an error for CUDA versions below 11.8. So I think we can declare this fixed.

hawkinsp avatar Nov 17 '23 21:11 hawkinsp