CUDA `XlaRuntimeError` with MPI on `jax==0.4.31`
Description
Hi,
`jax.jit` on a function seems to fail when running in an OpenMPI environment. A minimal working example (MWE) is shown below:
```python
# error.py
# Run as: mpirun -n 8 python error.py
import os

from absl import logging
import jax
import jax.numpy as jnp

logging.set_verbosity("info")
os.environ["no_proxy"] = "x.x.x.x"  # Internal use.
jax.distributed.initialize()

print("Hello from process %d holding %d device(s)"
      % (jax.process_index(), jax.local_device_count()))

def dot_product_attention(
    query: jnp.ndarray,
    key: jnp.ndarray,
    value: jnp.ndarray,
    *,
    dtype: jnp.dtype = jnp.float32) -> jnp.ndarray:
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key)
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
  return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value)

x = jnp.ones((1, 512, 8, 32), dtype=jnp.bfloat16)
f = lambda x: dot_product_attention(x, x, x)
print(jax.jit(f)(x))
```
The error can occur on a subset of processes (in which case I still see the output tensor) or on all processes (the job hangs or exits). I can confirm this error does not appear in `jax==0.4.30`.
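For what it's worth, the attention math in the MWE itself is unremarkable; a plain NumPy sketch of the same computation (hypothetical name `dot_product_attention_np`, same einsum layout) runs fine, which suggests the failure lies in the jitted/distributed path rather than in the computation. With an all-ones input the softmax weights are uniform, so the output is all ones:

```python
import numpy as np

def dot_product_attention_np(query, key, value):
    # Reference attention matching the MWE's einsum layout:
    # inputs have shape (batch, seq, heads, head_dim).
    depth = query.shape[-1]
    query = query / np.sqrt(depth)
    logits = np.einsum('...qhd,...khd->...hqk', query, key)
    # Numerically stable softmax over the last (key) axis.
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum('...hqk,...khd->...qhd', weights, value)

x = np.ones((1, 512, 8, 32), dtype=np.float32)
out = dot_product_attention_np(x, x, x)
print(out.shape)  # (1, 512, 8, 32)
```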
System info (python version, jaxlib version, accelerator, etc.)
Error log
```
JAX detected proxy variable(s) in the environment as distributed setup: no_proxy https_proxy HTTPS_PROXY HTTP_PROXY http_proxy. On some systems, this may cause a hang of distributed.initialize and you may need to unset these ENV variable(s)
[previous warning repeated once per process, 8 times in total]
Hello from process 3 holding 1 device(s)
Hello from process 5 holding 1 device(s)
Hello from process 1 holding 1 device(s)
Hello from process 7 holding 1 device(s)
Hello from process 0 holding 1 device(s)
Hello from process 4 holding 1 device(s)
Hello from process 6 holding 1 device(s)
Hello from process 2 holding 1 device(s)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/karan/workspace/jax_gpt2/error.py", line 14, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: DEADLINE_EXCEEDED: GetKeyValue() timed out with key: cuda:gemm_fusion_autotuning_results_5_0 and duration: -1ms
[traceback repeated, partially interleaved, on each failing process; one process instead reports the key cuda:gemm_fusion_autotuning_results_5_1]
--------------------------------------------------------------------------
Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[53590,1],2]
  Exit code:    1
--------------------------------------------------------------------------
```
System info:

```
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ubuntu', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May 7 09:00:52 UTC 2', machine='x86_64')
```

Truncated nvidia-smi info:

```
NVIDIA-SMI 555.42.06
Driver Version: 555.42.06
CUDA Version: 12.5
GPU: RTX A6000
```
@MasterSkepticista the error is related to getting `cuda:gemm_fusion_autotuning_results` on shards and may be related to https://github.com/openxla/xla/pull/13108 (cc @sergachev). To disable the autotuning and make your MWE work, you could try running it with:

```shell
XLA_FLAGS=--xla_gpu_shard_autotuning=false mpirun -n 8 python error.py
```

Let me know if this workaround helps.
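If editing the launch command is inconvenient, the same flag can also be appended to `XLA_FLAGS` from Python, as long as this runs before `jax` is first imported (XLA reads the variable at initialization). A minimal sketch that only manipulates the environment variable:

```python
import os

# Append the flag to any XLA_FLAGS already set; this must execute before
# the first `import jax`, since XLA reads XLA_FLAGS at initialization time.
flag = "--xla_gpu_shard_autotuning=false"
existing = os.environ.get("XLA_FLAGS", "")
if flag not in existing:
    os.environ["XLA_FLAGS"] = (existing + " " + flag).strip()

print(os.environ["XLA_FLAGS"])
```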
https://github.com/openxla/xla/pull/13108 was reverted.
`--xla_gpu_shard_autotuning=false` disables sharding of autotuning, not the autotuning itself.
I can reproduce with `jax==0.4.31`, and `--xla_gpu_shard_autotuning=false` helps. It looks like https://github.com/openxla/xla/pull/13108 got into this JAX release before it was reverted. Thank you for cc'ing me; I'll investigate why it fails.
@vfdev-5 Your suggestion worked. @sergachev I observed that JAX was built against https://github.com/openxla/xla/commit/95e3eea8d2aebd55160ed4185a38345ae98ab500, which predates the revert.
I sent a fix to XLA that makes the reproducer from this bug work. Independently of that, sharded autotuning was re-enabled yesterday and will likely make it into the next JAX release.