
CUDA `XlaRuntimeError` with MPI on `jax==0.4.31`

Open MasterSkepticista opened this issue 1 year ago • 5 comments

Description

Hi,

`jax.jit` on a function seems to fail when running in an OpenMPI environment. A minimal working example (MWE) is shown below:

# error.py
# Run as: mpirun -n 8 python error.py

import os
from absl import logging
import jax, jax.numpy as jnp

logging.set_verbosity("info")
os.environ["no_proxy"] = "x.x.x.x"  # Internal use.
jax.distributed.initialize()

print("Hello from process %d holding %d device(s)" % (jax.process_index(), jax.local_device_count()))

def dot_product_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
  # Scale queries by 1/sqrt(head depth), form attention weights,
  # and return the attention-weighted values.
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
  return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

x = jnp.ones((1, 512, 8, 32), dtype=jnp.bfloat16)
f = lambda x: dot_product_attention(x, x, x)

print(jax.jit(f)(x))

The error can appear on only some processes (in which case I still see the output tensor) or on all processes (the job hangs or exits). I can confirm this error does not appear in `jax==0.4.30`.
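The tracebacks in the log below are filtered by JAX; the internal frames can be recovered by disabling traceback filtering, as the error message itself suggests. A sketch of the invocation (environment variable name taken from the error message):

```shell
# Re-run the MWE with JAX's traceback filtering disabled to see
# the internal frames behind the XlaRuntimeError.
JAX_TRACEBACK_FILTERING=off mpirun -n 8 python error.py
```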

System info (python version, jaxlib version, accelerator, etc.)

Error log
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
[warning repeated by each of the 8 processes]
Hello from process 3 holding 1 device(s)
Hello from process 5 holding 1 device(s)
Hello from process 1 holding 1 device(s)
Hello from process 7 holding 1 device(s)
Hello from process 0 holding 1 device(s)
Hello from process 4 holding 1 device(s)
Hello from process 6 holding 1 device(s)
Hello from process 2 holding 1 device(s)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/karan/workspace/jax_gpt2/error.py", line 14, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: DEADLINE_EXCEEDED: GetKeyValue() timed out with key: cuda:gemm_fusion_autotuning_results_5_0 and duration: -1ms
[traceback repeated on the remaining failing processes; the timed-out key is cuda:gemm_fusion_autotuning_results_5_0 on most of them and cuda:gemm_fusion_autotuning_results_5_1 on one]
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[53590,1],2]
  Exit code:    1
--------------------------------------------------------------------------

System info:

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ubuntu', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May  7 09:00:52 UTC 2', machine='x86_64')

Truncated nvidia-smi info: 
NVIDIA-SMI 555.42.06              
Driver Version: 555.42.06      
CUDA Version: 12.5
GPU: RTX A6000

MasterSkepticista avatar Aug 12 '24 06:08 MasterSkepticista

@MasterSkepticista the error is related to fetching cuda:gemm_fusion_autotuning_results on shards and may be related to https://github.com/openxla/xla/pull/13108 (cc @sergachev). To disable the autotuning sharding and make your MWE work, you could try running it with:

XLA_FLAGS=--xla_gpu_shard_autotuning=false mpirun -n 8 python error.py
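The same flag can also be set from inside the script rather than on the command line; a sketch of that alternative (the flag name is taken from the workaround above, and the environment variable must be set before `jax` is imported so XLA picks it up at backend initialization):

```python
import os

# Append the workaround flag to any XLA_FLAGS already present,
# before importing jax.
flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (flags + " --xla_gpu_shard_autotuning=false").strip()

# import jax  # must happen only after XLA_FLAGS is set
```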

Let me know if this workaround helps.

vfdev-5 avatar Aug 14 '24 12:08 vfdev-5

https://github.com/openxla/xla/pull/13108 was reverted.

--xla_gpu_shard_autotuning=false disables sharding of autotuning, not the autotuning itself.
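In other words, with that flag each process still autotunes, but locally, instead of exchanging results through the distributed key-value store. A hedged sketch of the distinction (the autotune-level flag is a general XLA GPU flag, not one mentioned in this thread):

```shell
# Keep autotuning, but stop sharding its results across processes
# (the workaround above):
XLA_FLAGS=--xla_gpu_shard_autotuning=false mpirun -n 8 python error.py

# Assumed alternative: lower XLA's GPU autotune level to skip
# autotuning altogether (0 = off), trading tuned kernels for
# predictable compilation:
XLA_FLAGS=--xla_gpu_autotune_level=0 mpirun -n 8 python error.py
```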

sergachev avatar Aug 14 '24 14:08 sergachev

I can reproduce with jax==0.4.31, and --xla_gpu_shard_autotuning=false helps; it looks like https://github.com/openxla/xla/pull/13108 got into this JAX release before it was reverted. Thank you for cc'ing me, I'll investigate why it fails.

sergachev avatar Aug 14 '24 14:08 sergachev

@vfdev-5 Your suggestion worked. @sergachev I observed that JAX was built against https://github.com/openxla/xla/commit/95e3eea8d2aebd55160ed4185a38345ae98ab500, which was before the revert.

MasterSkepticista avatar Aug 14 '24 14:08 MasterSkepticista

I sent a fix to XLA that makes the reproducer from this bug work. Independently of that, sharded autotuning was re-enabled yesterday and will likely get into the next JAX release.

sergachev avatar Aug 16 '24 17:08 sergachev