Cryptic XLA error when using JAX 0.4.26
Description
Dear JAX Team,
I am transitioning my code from jax 0.4.25 to jax 0.4.26 because I would like to use it with the latest NGC container, and I am running into an XLA error that I cannot quite understand. Unfortunately, I could not reproduce the error with a small, self-contained example. However, I have observed that the error occurs only when using a batch size greater than one. Here is the error message:
2024-05-23 15:20:04.554952: F external/xla/xla/shape_tree.cc:54] Check failed: result->children_start_id >= 0 (0 vs. -1)
Could you please help me diagnose and resolve this issue? Any guidance would be greatly appreciated.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='IgorPC', release='5.15.146.1-microsoft-standard-WSL2', version='#1 SMP Thu Jan 11 04:09:03 UTC 2024', machine='x86_64')
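The system information above matches the format printed by `jax.print_environment_info()`. The snippet below is a small sketch of collecting that report programmatically so it can be pasted into an issue; capturing stdout via `contextlib.redirect_stdout` is just one convenient choice, not something the thread prescribes.

```python
import contextlib
import io

import jax

# jax.print_environment_info() prints the jax/jaxlib/numpy/python versions,
# the visible devices, the process count and the platform (plus nvidia-smi
# output when it is available). Capture the printed text so it can be pasted
# into a bug report.
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    jax.print_environment_info()
env_info = buffer.getvalue()

print(env_info)
```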
Can you please try with (a) a fresh virtualenv and (b) using jax[cuda12] v0.4.28, which is the current version?
Hi @hawkinsp,
Thank you for the quick response. I followed your instructions but encountered the same error. To provide more context, the error occurs during batched operations with a model written in Equinox. Unfortunately, I have not been able to recreate it with simpler examples, leading me to believe the issue is not on Equinox's side. @patrick-kidger, have you encountered similar issues by any chance?
I've not seen this one before I'm afraid!
Can you share an HLO dump? That might be enough for us to reproduce. Run with XLA_FLAGS=--xla_dump_to=/somewhere, zip up /somewhere and attach it to this bug.
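A minimal sketch of producing such a dump from Python, assuming the flag is set before JAX initialises its backend (XLA reads `XLA_FLAGS` at that point). The function `f` is a hypothetical placeholder for the code that actually crashes, and the temporary directory stands in for `/somewhere`.

```python
import os
import shutil
import tempfile

# XLA reads XLA_FLAGS when the backend is initialised, so set it before
# importing jax. Any flags that are already set are preserved.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + f" --xla_dump_to={dump_dir}"
).strip()

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Hypothetical placeholder: run the failing model/batch here instead.
    return jnp.sin(x) * 2.0


# Compiling and running f makes XLA write its HLO modules into dump_dir.
f(jnp.arange(4.0)).block_until_ready()

# Zip the dump directory (the archive is written next to it, not inside it)
# so it can be attached to the bug report.
archive_path = shutil.make_archive(dump_dir, "zip", root_dir=dump_dir)
print(archive_path, len(os.listdir(dump_dir)), "files dumped")
```

If the crash happens before anything finishes compiling, the dump directory may still contain the modules XLA had started on, which is often what the compiler team needs.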
Hmm. I can't reproduce from the HLO dump. I think we'll need a Python-level reproduction.
Actually, never mind, I can reproduce from the HLO.
Hi @hawkinsp,
I am checking in to see if there are any updates on this issue. Any information would be helpful. Thanks!
I filed an internal bug for our XLA compiler folks (b/342589917), and I'm waiting for one of them to take a look.
I have no additional information other than: yes, I can reproduce the problem from the HLO dump.
A fix for this was merged (https://github.com/openxla/xla/commit/e7bd8addde659ed53f292910c12a50f708c5b566). The fix should be in today's jaxlib nightly. Please try it out and let me know if the problem is fixed.
Hi @hawkinsp, apologies for the late response. I tried following the installation instructions for the JAX nightly, but I ended up with the following error:
[INFO 06-10 00:27:06] metatopia: Running version 0.0.1 on the GPU.
E0610 00:27:06.855497 17144 cuda_dnn.cc:535] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0610 00:27:06.855663 17144 cuda_dnn.cc:539] Memory usage: 7446986752 bytes free, 8585281536 bytes total.
E0610 00:27:06.855951 17144 cuda_dnn.cc:535] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0610 00:27:06.856039 17144 cuda_dnn.cc:539] Memory usage: 7446986752 bytes free, 8585281536 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/igork/projects/metatopia/experiments/topology_optimisation/run.py", line 6, in <module>
import metatopia as mtp
File "/home/igork/projects/metatopia/metatopia/__init__.py", line 21, in <module>
from metatopia import (task_generation, filters, solver, models, utils,
File "/home/igork/projects/metatopia/metatopia/task_generation/__init__.py", line 1, in <module>
from .problems import (mbb_beam, generate_random_2D_problem,
File "/home/igork/projects/metatopia/metatopia/task_generation/problems.py", line 175, in <module>
fixed_key: jr.PRNGKey = jr.PRNGKey(0)):
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/random.py", line 233, in PRNGKey
return _return_prng_keys(True, _key('PRNGKey', seed, impl))
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/random.py", line 195, in _key
return prng.random_seed(seed, impl=impl)
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/prng.py", line 532, in random_seed
seeds_arr = jnp.asarray(np.int64(seeds))
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3153, in asarray
return array(a, dtype=dtype, copy=bool(copy), order=order)
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3078, in array
out_array: Array = lax_internal._convert_element_type(
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 559, in _convert_element_type
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 416, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 909, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
For what it's worth, I've encountered the same error in jaxlib 0.4.28 when using jax.experimental.io_callback, but I haven't managed to reproduce it with a minimal Python example yet.
I also tried installing the nightly and initially got the exact same error as @itk22, even when only running a simple command. Doing a clean install fixed the problems with running the nightly.
The nightly then produced a more informative traceback, which allowed me to fix the issue.
The actual issue was that I had specified the wrong shape for the return value of jax.experimental.io_callback.
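A minimal sketch of the correct pattern, assuming the mistake was a mismatch between what the host callback returns and the `result_shape_dtypes` argument passed to `io_callback`. The names `host_double` and `f` are hypothetical; the point is only that the declared `jax.ShapeDtypeStruct` must describe the callback's actual output exactly.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback


def host_double(x):
    # Runs on the host. Whatever it returns must match result_shape_dtypes
    # below exactly, in both shape and dtype.
    return np.asarray(x, dtype=np.float32) * np.float32(2.0)


@jax.jit
def f(x):
    # Declare the callback's output: same shape as x, float32.
    result_spec = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    return io_callback(host_double, result_spec, x)


y = f(jnp.arange(3, dtype=jnp.float32))
print(y)
```

When the callback's output depends on a batch dimension, it is worth double-checking that `result_spec` is derived from the actual input shapes (as with `x.shape` above) rather than hard-coded, since a spec that happens to be right for one batch size can be wrong for another.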