
jnp.cumsum is incorrect when sharding over summing axis on GPUs

Status: Open. wilson1yan opened this issue 1 year ago • 0 comments

Description

The simple script below tests jnp.cumsum on an array that is sharded along the same axis it sums over. Tested on a machine with eight 40 GB A100 GPUs.

# test.py
import argparse
import jax
import jax.numpy as jnp
import jax.lax
from jax.sharding import PartitionSpec as PS, Mesh
from jax.experimental.pjit import pjit
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument('--sp', type=int, default=2)
args = parser.parse_args()

def f(mask):
    mask = jax.lax.with_sharding_constraint(mask, PS('dp', 'sp'))
    idxs = jnp.cumsum(mask, axis=-1)
    return idxs
f = pjit(f, in_shardings=PS(), out_shardings=PS())

mesh_shape = (jax.device_count() // args.sp, args.sp)
print(f"Mesh shape: {mesh_shape}")
mesh = Mesh(np.array(jax.devices()).reshape(mesh_shape), ('dp', 'sp'))

B, L = 8, 1024
with mesh:
    mask = np.ones((B, L), dtype=np.int32)
    out = jax.device_get(f(mask))
    expected = np.arange(L, dtype=np.int32)[None].repeat(B, axis=0) + 1
    print('Output:', out)
    print('Expected:', expected)
    assert np.allclose(out, expected)
    print('Passed')
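For reference, here is a minimal NumPy sketch of what a cumsum sharded along its summed axis has to compute. It assumes each 'sp' shard holds one contiguous block of the last axis, and it models only the math, not how XLA actually partitions the op; blocked_cumsum is a hypothetical helper. Each block takes a local cumsum and then adds the sum of every block to its left.

```python
import numpy as np


def blocked_cumsum(x, num_shards):
    """Cumsum along the last axis, computed as if that axis were split into
    num_shards contiguous blocks (a model of one block per 'sp' device)."""
    if x.shape[-1] % num_shards:
        raise ValueError("last axis must split evenly across shards")
    blocks = np.split(x, num_shards, axis=-1)
    local = [np.cumsum(b, axis=-1) for b in blocks]            # per-block cumsum
    totals = np.stack([lc[..., -1] for lc in local], axis=-1)  # (..., num_shards)
    offsets = np.cumsum(totals, axis=-1) - totals              # sum of blocks to the left
    return np.concatenate(
        [lc + offsets[..., i:i + 1] for i, lc in enumerate(local)], axis=-1
    )


B, L = 8, 1024
mask = np.ones((B, L), dtype=np.int32)
for sp in (1, 2, 4, 8):
    # Every split must reproduce the unsharded result exactly.
    assert np.array_equal(blocked_cumsum(mask, sp), np.cumsum(mask, axis=-1))
```

Since every split reproduces np.cumsum, the expected array in test.py is the correct target no matter how many 'sp' shards the axis is divided into.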

On TPUs, this code works fine for every mesh sharding I tried.

On GPUs, the code produces incorrect output whenever args.sp > 1, i.e. whenever the summed axis is sharded.

It works when the summed axis is not sharded (sp = 1):

> python test.py --sp 1
Mesh shape: (8, 1)
2024-05-23 18:27:20.884869: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Output: [[   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 ...
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]]
Expected: [[   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 ...
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]]
Passed

It breaks when the summed axis is sharded (tested with sp = 2, 4, and 8, the maximum on 8 GPUs):

> python test.py --sp 2
Mesh shape: (4, 2)
Output: [[   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]
 ...
 [   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]
 [   1    2    4 ... 1278 1279 1024]]
Expected: [[   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 ...
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]
 [   1    2    3 ... 1022 1023 1024]]
Traceback (most recent call last):
  File "/home/wilsonyan/test.py", line 31, in <module>
    assert np.allclose(out, expected)
AssertionError
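NumPy's summarized printing hides where the two arrays actually diverge. A small helper like the one below (hypothetical names, NumPy only, 2-D inputs assumed) reports the first wrong column and how many wrong entries fall inside each contiguous 'sp' block, which would show whether the errors line up with shard boundaries. In test.py it could be called just before the assert, e.g. print(mismatch_by_shard(out, expected, args.sp)). The demo data below is synthetic, not output from the report.

```python
import numpy as np


def first_mismatch_column(out, expected):
    """Index of the first column (last axis) where any row differs, or None."""
    cols = np.nonzero(np.any(out != expected, axis=0))[0]
    return int(cols[0]) if cols.size else None


def mismatch_by_shard(out, expected, num_shards):
    """Number of wrong entries inside each contiguous block of the last axis."""
    bad = out != expected
    return [int(b.sum()) for b in np.split(bad, num_shards, axis=-1)]


# Synthetic demo: corrupt column 5, which lies in the second of two blocks.
expected = np.arange(1, 9)[None].repeat(2, axis=0)  # shape (2, 8)
out = expected.copy()
out[:, 5] += 1
print(first_mismatch_column(out, expected))  # 5
print(mismatch_by_shard(out, expected, 2))   # [0, 2]
```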

Another bizarre thing: if you insert mask = jnp.ones_like(mask) after the sharding constraint, like so:

...
def f(mask):
    mask = jax.lax.with_sharding_constraint(mask, PS('dp', 'sp'))
    mask = jnp.ones_like(mask) # <<< new inserted line
    idxs = jnp.cumsum(mask, axis=-1)
    return idxs
...

The code works again on GPUs.
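Since the ones_like variant passes while the original fails, diffing the compiled HLO of the two could narrow things down. The sketch below is an assumption, not part of the original script: it rebuilds f with jax.jit and a NamedSharding instead of pjit plus a mesh context, and returns the optimized, partitioned HLO text through jax.jit(...).lower(...).compile().as_text(). build and compiled_hlo are hypothetical helper names. On the 8-GPU machine one would build the (4, 2) mesh exactly as in test.py and compare compiled_hlo(mesh, False, mask) with compiled_hlo(mesh, True, mask) to see whether the two variants are partitioned differently.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as PS


def build(mesh, use_ones_like):
    """Jitted equivalent of f from test.py (assumes mesh axes 'dp' and 'sp')."""
    sharding = NamedSharding(mesh, PS('dp', 'sp'))

    def f(mask):
        mask = jax.lax.with_sharding_constraint(mask, sharding)
        if use_ones_like:
            mask = jnp.ones_like(mask)  # the variant reported to pass on GPU
        return jnp.cumsum(mask, axis=-1)

    return jax.jit(f)


def compiled_hlo(mesh, use_ones_like, x):
    """Optimized HLO text that XLA produced for the current backend and mesh."""
    return build(mesh, use_ones_like).lower(x).compile().as_text()
```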

Any clue what might be going on?
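As a possible stopgap, the cumsum can be written explicitly with shard_map, so each device does only a local cumsum plus a small all_gather of per-shard totals instead of relying on XLA's automatic partitioning of jnp.cumsum. This is a sketch, not a confirmed fix: it has not been checked against the GPU bug, it assumes a 2-D input laid out as PS('dp', 'sp') on a mesh with axes named 'dp' and 'sp', and sharded_cumsum is a hypothetical name. The import fallback is meant to cover both newer JAX (jax.shard_map) and 0.4.x releases such as the 0.4.28 used here (jax.experimental.shard_map).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as PS

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # JAX 0.4.x
    from jax.experimental.shard_map import shard_map


def sharded_cumsum(x, mesh):
    """Cumsum along the last axis of a 2-D array laid out as PS('dp', 'sp')."""

    def local(block):
        # block is this device's (B/dp, L/sp) tile.
        local_cs = jnp.cumsum(block, axis=-1)
        total = local_cs[:, -1:]                                      # (B/dp, 1)
        totals = jax.lax.all_gather(total, 'sp', axis=1, tiled=True)  # (B/dp, sp)
        me = jax.lax.axis_index('sp')
        before = jnp.arange(totals.shape[1]) < me                     # shards to my left
        offset = jnp.sum(jnp.where(before, totals, 0), axis=1, keepdims=True)
        return local_cs + offset

    fn = shard_map(local, mesh=mesh, in_specs=PS('dp', 'sp'), out_specs=PS('dp', 'sp'))
    return jax.jit(fn)(x)
```

In test.py this would stand in for the call to f, e.g. out = jax.device_get(sharded_cumsum(mask, mesh)).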

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', release='6.8.0-1007-gcp', version='#7-Ubuntu SMP Sat Apr 20 00:58:31 UTC 2024', machine='x86_64')

$ nvidia-smi
Thu May 23 18:28:53 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             59W /  400W |     425MiB /  40960MiB |      3%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   33C    P0             69W /  400W |     425MiB /  40960MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          Off |   00000000:00:06.0 Off |                    0 |
| N/A   31C    P0             61W /  400W |     425MiB /  40960MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A100-SXM4-40GB          Off |   00000000:00:07.0 Off |                    0 |
| N/A   32C    P0             59W /  400W |     425MiB /  40960MiB |      2%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA A100-SXM4-40GB          Off |   00000000:80:00.0 Off |                    0 |
| N/A   32C    P0             61W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA A100-SXM4-40GB          Off |   00000000:80:01.0 Off |                    0 |
| N/A   33C    P0             59W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA A100-SXM4-40GB          Off |   00000000:80:02.0 Off |                    0 |
| N/A   32C    P0             60W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA A100-SXM4-40GB          Off |   00000000:80:03.0 Off |                    0 |
| N/A   34C    P0             66W /  400W |     425MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    345969      C   python                                        416MiB |
|    1   N/A  N/A    345969      C   python                                        416MiB |
|    2   N/A  N/A    345969      C   python                                        416MiB |
|    3   N/A  N/A    345969      C   python                                        416MiB |
|    4   N/A  N/A    345969      C   python                                        416MiB |
|    5   N/A  N/A    345969      C   python                                        416MiB |
|    6   N/A  N/A    345969      C   python                                        416MiB |
|    7   N/A  N/A    345969      C   python                                        416MiB |
+-----------------------------------------------------------------------------------------+

wilson1yan, May 23 '24 18:05