
`roll` incorrect if jit-compiled on GPU with 64-bit mode

Open · mathisgerdes opened this issue on Nov 16 '22 · 5 comments

Description

For arrays of length 2^n (tested 4, 8, 16, and 64), `roll` gives different results before and after applying `jit`. This behavior appears only on the GPU in 64-bit mode and depends on the CUDA/cuDNN version. I can't tell whether this is a problem with JAX, some underlying library, or possibly a faulty installation.

In 64-bit mode:

from jax.config import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp

def roll(arr, index):
    return jnp.roll(arr, -index)

l = jnp.arange(8)  # example input; int64 because 64-bit mode is enabled

roll gives different results before and after jit:

>>> roll(l, 1)
DeviceArray([1, 2, 3, 4, 5, 6, 7, 0], dtype=int64)
>>> jax.jit(roll)(l, 1)
DeviceArray([7, 0, 1, 2, 3, 4, 5, 6], dtype=int64)

I found this behavior on two different machines (see below), with cuDNN 8.4.1.50, CUDA 11.7.0, and multiple versions of JAX.

The problem disappeared when using cuDNN 8.2.1.32 and CUDA 11.3.1. I also could not reproduce this on Colab with jax v0.3.25, jaxlib v0.3.25 on a Tesla T4 GPU (CUDA 11.2).

What jax/jaxlib version are you using?

Both jax v0.3.25, jaxlib v0.3.25 and jax v0.3.14, jaxlib v0.3.14

Which accelerator(s) are you using?

GPU, 64-bit mode

Additional system info

Linux, Python 3.10.4

NVIDIA GPU info

Reproduced this on

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.76       Driver Version: 515.76       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA TITAN RTX    On   | 00000000:5E:00.0 Off |                  N/A |
| 41%   37C    P8    23W / 280W |      1MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

as well as

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.76       Driver Version: 515.76       CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:3B:00.0 Off |                  N/A |
|  0%   33C    P8     8W / 250W |      1MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:5E:00.0 Off |                  N/A |
|  0%   25C    P8     7W / 250W |      1MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA GeForce ...  On   | 00000000:B1:00.0 Off |                  N/A |
|  0%   24C    P8    14W / 250W |      1MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA GeForce ...  On   | 00000000:D9:00.0 Off |                  N/A |
|  0%   26C    P8     8W / 250W |      1MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

mathisgerdes commented on Nov 16 '22

Smaller repro:

import jax, jax.lax as lax
jax.config.update("jax_enable_x64", True)

def f(x):
  return lax.rem(-x, 8)

print(f(1))
print(jax.jit(f)(1))

which prints

-1
1

when it should print the same thing both times.

XLA ends up optimizing this to:

def f(a):
  x = -a
  y = lax.select(x < 0, -x, x)
  z = lax.bitwise_and(y, 0x7)
  return lax.select(x < 0, -z, z)

and that repro also reproduces the bug.
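
For reference, the rewrite itself is valid: for a power-of-two divisor d, the truncated remainder that `lax.rem` computes satisfies rem(x, d) == sign(x) * (|x| & (d - 1)). A minimal NumPy sketch of that identity, purely illustrative:

import numpy as np

# Truncated remainder (sign follows the dividend), as lax.rem computes it,
# versus the abs/mask/negate form used by the rewrite above.
xs = np.arange(-20, 21)
lhs = np.fmod(xs, 8)                   # truncated remainder
rhs = np.sign(xs) * (np.abs(xs) & 7)   # |x| & (d - 1), then re-apply the sign
assert np.array_equal(lhs, rhs)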

hawkinsp commented on Nov 16 '22

This appears to be a miscompilation of some sort in ptxas, which is provided by NVIDIA. As a workaround, you should downgrade your CUDA installation to a version without the bug.
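
To check whether a particular installation is affected, comparing the eager and jitted results of the repro above should suffice; a minimal sketch, assuming the same `lax.rem` pattern is what gets miscompiled:

import jax
import jax.lax as lax

jax.config.update("jax_enable_x64", True)

def f(x):
  return lax.rem(-x, 8)

# On an affected ptxas the jitted result is 1, while the eager result is -1.
print("affected:", int(f(1)) != int(jax.jit(f)(1)))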

In both cases, JAX generates the following PTX, which looks correct:

//
// Generated by LLVM NVPTX Back-End
//

.version 7.0
.target sm_80
.address_size 64

        // .globl       fusion

.visible .entry fusion(
        .param .u64 fusion_param_0,
        .param .u64 fusion_param_1
)
.reqntid 1, 1, 1
{
        .reg .pred      %p<2>;
        .reg .b64       %rd<11>;

        ld.param.u64    %rd1, [fusion_param_0];
        ld.param.u64    %rd2, [fusion_param_1];
        cvta.to.global.u64      %rd3, %rd2;
        cvta.to.global.u64      %rd4, %rd1;
        ld.global.nc.u64        %rd5, [%rd4];
        neg.s64         %rd6, %rd5;
        setp.lt.s64     %p1, %rd6, 0;
        abs.s64         %rd7, %rd5;
        and.b64         %rd8, %rd7, 7;
        neg.s64         %rd9, %rd8;
        selp.b64        %rd10, %rd9, %rd8, %p1;
        st.global.u64   [%rd3], %rd10;
        ret;

}

With ptxas from CUDA 11.8 we get a wrong output, but with ptxas from CUDA 11.6 we get a correct output.

hawkinsp commented on Nov 16 '22

@nouiz @yhtang for viz

mjsML commented on Nov 16 '22

Filed NVIDIA partners bug 3872915

hawkinsp commented on Nov 16 '22

Hi @mathisgerdes

Looks like this issue has been resolved in later versions of JAX. I executed the mentioned code on Colab (T4 GPU) with CUDA 12.3, cuDNN 8.9.7, and JAX versions 0.4.23 and 0.4.25. `roll` produces the same output with and without `jit`.

I also verified with CUDA 11.8, cuDNN 8.9.6, and JAX versions 0.4.23 and 0.4.25; the output is the same with and without jit compilation.

from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp

def roll(arr, index):
    return jnp.roll(arr, -index)

Without JIT-Compilation:

l = jnp.arange(8)
print(l)
roll(l, 1)

Output:

[0 1 2 3 4 5 6 7]
Array([1, 2, 3, 4, 5, 6, 7, 0], dtype=int64)

With JIT-Compilation:

jax.jit(roll)(l, 1)

Output:

Array([1, 2, 3, 4, 5, 6, 7, 0], dtype=int64)

Kindly find the gist for reference.

Thank you

rajasekharporeddy commented on Feb 27 '24

Hi @mathisgerdes

Please feel free to close the issue, if it is resolved.

Thank you.

rajasekharporeddy commented on Feb 28 '24