`roll` incorrect if jit-compiled on GPU with 64-bit mode
Description
For arrays of length 2^n (tested 4, 8, 16, and 64), `roll` gives different results before and after applying `jit`.
This behavior appears only on the GPU in 64-bit mode and depends on the CUDA/cuDNN version.
I can't tell if this is a problem with JAX, some underlying library, or possibly a faulty installation.
In 64-bit mode:

from jax.config import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp

def roll(arr, index):
    return jnp.roll(arr, -index)

l = jnp.arange(8)
`roll` gives different results before and after `jit`:
>>> roll(l, 1)
DeviceArray([1, 2, 3, 4, 5, 6, 7, 0], dtype=int64)
>>> jax.jit(roll)(l, 1)
DeviceArray([7, 0, 1, 2, 3, 4, 5, 6], dtype=int64)
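A quick way to scan for the mismatch across all the reported sizes (my own sketch, not from the original report; on an unaffected platform every comparison passes):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def roll(arr, index):
    return jnp.roll(arr, -index)

# Compare eager vs. jitted results for every array length mentioned
# in the report; on a healthy install all comparisons pass.
for n in (4, 8, 16, 64):
    l = jnp.arange(n)
    eager = roll(l, 1)
    jitted = jax.jit(roll)(l, 1)
    assert bool(jnp.array_equal(eager, jitted)), f"mismatch for n={n}"
```

On an affected GPU setup the assertion fires, since the jitted result is rolled in the wrong direction.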
I found this behavior on two different machines (see below) with cuDNN 8.4.1.50, CUDA 11.7.0, and multiple versions of JAX. The problem disappeared when using cuDNN 8.2.1.32 and CUDA 11.3.1. I also could not reproduce this on Colab with jax v0.3.25 and jaxlib v0.3.25 on a Tesla T4 GPU (CUDA 11.2).
What jax/jaxlib version are you using?
Both jax v0.3.25, jaxlib v0.3.25 and jax v0.3.14, jaxlib v0.3.14
Which accelerator(s) are you using?
GPU, 64bit mode
Additional system info
Linux, python 3.10.4
NVIDIA GPU info
Reproduced this on
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.76 Driver Version: 515.76 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA TITAN RTX On | 00000000:5E:00.0 Off | N/A |
| 41% 37C P8 23W / 280W | 1MiB / 24576MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
as well as
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.76 Driver Version: 515.76 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:3B:00.0 Off | N/A |
| 0% 33C P8 8W / 250W | 1MiB / 11264MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... On | 00000000:5E:00.0 Off | N/A |
| 0% 25C P8 7W / 250W | 1MiB / 11264MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce ... On | 00000000:B1:00.0 Off | N/A |
| 0% 24C P8 14W / 250W | 1MiB / 11264MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce ... On | 00000000:D9:00.0 Off | N/A |
| 0% 26C P8 8W / 250W | 1MiB / 11264MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
Smaller repro:

import jax
import jax.lax as lax

jax.config.update("jax_enable_x64", True)

def f(x):
    return lax.rem(-x, 8)

print(f(1))
print(jax.jit(f)(1))
which prints
-1
1
when it should print the same thing both times.
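For context, `roll` boils down to modular index arithmetic, which is presumably why a bug in the compiled remainder surfaces as a wrong roll. A minimal NumPy sketch of that reduction (an illustration, not JAX's actual implementation):

```python
import numpy as np

def roll_ref(arr, shift):
    # Rolling by `shift` gathers each output element from position
    # (i - shift) mod n, so a remainder of the shift against the
    # array length sits on the hot path.
    n = len(arr)
    idx = (np.arange(n) - shift) % n
    return arr[idx]

a = np.arange(8)
assert (roll_ref(a, -1) == np.roll(a, -1)).all()
```

A remainder computed with the wrong sign shifts every gathered index, which matches the "rolled the wrong way" output seen above.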
XLA ends up optimizing this to:

def f(a):
    x = -a
    y = lax.select(x < 0, -x, x)
    z = lax.bitwise_and(y, 0x7)
    return lax.select(x < 0, -z, z)

and that rewritten repro exhibits the bug as well.
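As a sanity check on the rewrite itself (plain Python, my own sketch): for a power-of-two divisor d, the truncated remainder that `lax.rem` computes is equivalent to masking the absolute value with d - 1 and restoring the sign, which is exactly the select/and/select pattern above:

```python
def rem_trunc(x, d):
    # Truncated (C-style) remainder, matching lax.rem semantics:
    # the result takes the sign of the dividend.
    r = abs(x) % d
    return -r if x < 0 else r

def rem_rewritten(x, d):
    # The pattern XLA emitted: take |x|, mask with d - 1, restore sign.
    # Only valid when d is a power of two.
    y = -x if x < 0 else x
    z = y & (d - 1)
    return -z if x < 0 else z

assert all(rem_trunc(x, 8) == rem_rewritten(x, 8) for x in range(-32, 33))
```

In particular, rem_trunc(-1, 8) and rem_rewritten(-1, 8) both give -1, matching the expected eager output, so the rewrite is semantically sound and the fault lies in how it was compiled.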
This appears to be a miscompilation of some sort in `ptxas`, which is provided by NVIDIA. As a workaround, you can downgrade your CUDA installation to a version without the bug.
In both cases, JAX generates the following PTX, which looks correct:
//
// Generated by LLVM NVPTX Back-End
//
.version 7.0
.target sm_80
.address_size 64
// .globl fusion
.visible .entry fusion(
.param .u64 fusion_param_0,
.param .u64 fusion_param_1
)
.reqntid 1, 1, 1
{
.reg .pred %p<2>;
.reg .b64 %rd<11>;
ld.param.u64 %rd1, [fusion_param_0];
ld.param.u64 %rd2, [fusion_param_1];
cvta.to.global.u64 %rd3, %rd2;
cvta.to.global.u64 %rd4, %rd1;
ld.global.nc.u64 %rd5, [%rd4];
neg.s64 %rd6, %rd5;
setp.lt.s64 %p1, %rd6, 0;
abs.s64 %rd7, %rd5;
and.b64 %rd8, %rd7, 7;
neg.s64 %rd9, %rd8;
selp.b64 %rd10, %rd9, %rd8, %p1;
st.global.u64 [%rd3], %rd10;
ret;
}
With `ptxas` from CUDA 11.8 we get wrong output, but with `ptxas` from CUDA 11.6 we get correct output.
@nouiz @yhtang for viz
Filed NVIDIA partners bug 3872915
Hi @mathisgerdes
It looks like this issue has been resolved in later versions of JAX. I executed the code above on Colab (T4 GPU) with CUDA 12.3, cuDNN 8.9.7, and JAX versions 0.4.23 and 0.4.25; `roll` produces the same output with and without `jit`.
I also verified with CUDA 11.8 and cuDNN 8.9.6 on JAX versions 0.4.23 and 0.4.25; the output is the same with and without jit-compilation.
from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp

def roll(arr, index):
    return jnp.roll(arr, -index)
Without JIT-compilation:
l = jnp.arange(8)
print(l)
roll(l, 1)
Output:
[0 1 2 3 4 5 6 7]
Array([1, 2, 3, 4, 5, 6, 7, 0], dtype=int64)
With JIT-compilation:
jax.jit(roll)(l, 1)
Output:
Array([1, 2, 3, 4, 5, 6, 7, 0], dtype=int64)
Kindly find the gist for reference.
Thank you
Hi @mathisgerdes
Please feel free to close the issue, if it is resolved.
Thank you.