Cryptic XLA error when using JAX 0.4.26

Open itk22 opened this issue 1 year ago • 4 comments

Description

Dear JAX Team,

I am migrating my code from jax 0.4.25 to jax 0.4.26 so that I can use it with the latest NGC container, and I am running into an XLA error that I cannot quite understand. Unfortunately, I could not reproduce the error with a small piece of code, but I have observed that it occurs only when the batch size is greater than one. Here is the error message:

```
2024-05-23 15:20:04.554952: F external/xla/xla/shape_tree.cc:54] Check failed: result->children_start_id >= 0 (0 vs. -1)
```

Could you please help diagnose and resolve this issue? Any guidance would be greatly appreciated.

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='IgorPC', release='5.15.146.1-microsoft-standard-WSL2', version='#1 SMP Thu Jan 11 04:09:03 UTC 2024', machine='x86_64')
```

itk22 avatar May 23 '24 13:05 itk22

Can you please try with (a) a fresh virtualenv and (b) using jax[cuda12] v0.4.28, which is the current version?

hawkinsp avatar May 23 '24 19:05 hawkinsp

Hi @hawkinsp,

Thank you for the quick response. I followed your instructions but encountered the same error. To provide more context, the error occurs during batched operations with a model written in Equinox. Unfortunately, I have not been able to recreate it with simpler examples, leading me to believe the issue is not on Equinox's side. @patrick-kidger, have you encountered similar issues by any chance?

itk22 avatar May 24 '24 08:05 itk22

I've not seen this one before I'm afraid!

patrick-kidger avatar May 24 '24 10:05 patrick-kidger

Can you share an HLO dump? That might be enough for us to reproduce. Run with XLA_FLAGS=--xla_dump_to=/somewhere, zip up /somewhere and attach it to this bug.

hawkinsp avatar May 24 '24 12:05 hawkinsp
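
For readers following along, the HLO dump requested above can also be produced from inside a Python script. The following is a minimal sketch, not code from this thread: the helper names `enable_hlo_dump` and `zip_dump` and the temporary dump directory are assumptions made for illustration. It sets `XLA_FLAGS` before `jax` is imported, which is the conservative ordering, and then zips the dump directory for attaching to the bug.

```python
import os
import shutil
import tempfile


def enable_hlo_dump(dump_dir: str) -> None:
    """Add --xla_dump_to=<dump_dir> to XLA_FLAGS (hypothetical helper).

    Call this before importing jax so the flag is in place when XLA starts.
    """
    os.makedirs(dump_dir, exist_ok=True)
    existing = os.environ.get("XLA_FLAGS", "")
    # Preserve any flags that were already set in the environment.
    os.environ["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()


def zip_dump(dump_dir: str) -> str:
    """Zip the dump directory and return the archive path (hypothetical helper)."""
    return shutil.make_archive(dump_dir.rstrip("/"), "zip", root_dir=dump_dir)


# Assumed location; any writable directory works.
dump_dir = os.path.join(tempfile.mkdtemp(), "xla_dump")
enable_hlo_dump(dump_dir)

# import jax  # import only after XLA_FLAGS has been set
# ... run the computation that triggers the error here ...

archive_path = zip_dump(dump_dir)
print("Attach this file to the issue:", archive_path)
```

The same effect is obtained from the shell by prefixing the command with `XLA_FLAGS=--xla_dump_to=/somewhere`, as suggested in the comment above.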

@hawkinsp, here is the HLO dump for the run with 0.4.28

dump.zip

itk22 avatar May 24 '24 13:05 itk22

Hmm. I can't reproduce from the HLO dump. I think we'll need a Python-level reproduction.

hawkinsp avatar May 24 '24 14:05 hawkinsp

Actually, never mind, I can reproduce from the HLO.

hawkinsp avatar May 24 '24 14:05 hawkinsp

Hi @hawkinsp,

I am checking in to see if there are any updates on this issue. Any information would be helpful. Thanks!

itk22 avatar May 29 '24 12:05 itk22

I filed an internal bug for our XLA compiler folks (b/342589917), and I'm waiting for one of them to take a look.

I have no additional information other than: yes, I can reproduce the problem from the HLO dump.

hawkinsp avatar May 29 '24 13:05 hawkinsp

A fix for this was merged (https://github.com/openxla/xla/commit/e7bd8addde659ed53f292910c12a50f708c5b566). The fix should be in today's jaxlib nightly. Please try it out and let me know if the problem is fixed.

hawkinsp avatar Jun 04 '24 14:06 hawkinsp

Hi @hawkinsp, Apologies for the late response. I tried following the installation instructions for JAX nightly but I ended up with the following error:

```
[INFO 06-10 00:27:06] metatopia: Running version 0.0.1 on the GPU.
E0610 00:27:06.855497   17144 cuda_dnn.cc:535] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0610 00:27:06.855663   17144 cuda_dnn.cc:539] Memory usage: 7446986752 bytes free, 8585281536 bytes total.
E0610 00:27:06.855951   17144 cuda_dnn.cc:535] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0610 00:27:06.856039   17144 cuda_dnn.cc:539] Memory usage: 7446986752 bytes free, 8585281536 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/igork/projects/metatopia/experiments/topology_optimisation/run.py", line 6, in <module>
    import metatopia as mtp
  File "/home/igork/projects/metatopia/metatopia/__init__.py", line 21, in <module>
    from metatopia import (task_generation, filters, solver, models, utils,
  File "/home/igork/projects/metatopia/metatopia/task_generation/__init__.py", line 1, in <module>
    from .problems import (mbb_beam, generate_random_2D_problem,
  File "/home/igork/projects/metatopia/metatopia/task_generation/problems.py", line 175, in <module>
    fixed_key: jr.PRNGKey = jr.PRNGKey(0)):
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/random.py", line 233, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/random.py", line 195, in _key
    return prng.random_seed(seed, impl=impl)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/prng.py", line 532, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3153, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3078, in array
    out_array: Array = lax_internal._convert_element_type(
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 559, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/core.py", line 909, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/igork/tools/miniconda3/envs/metatopia/lib/python3.10/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
```

itk22 avatar Jun 09 '24 22:06 itk22

For what it's worth, I've encountered the same error in jaxlib 0.4.28 when using jax.experimental.io_callback, but haven't managed to get a minimal Python example to reproduce it yet.

~~I also tried to install the nightly, and got the exact same error as @itk22 even when only trying to run a simple command:~~ Doing a clean install fixed the issues with running the nightly.

The nightly then produced a nice new traceback that allowed me to fix the issue. The actual problem was that I had declared the wrong shape for the return value of jax.experimental.io_callback.

JohannesAck avatar Jun 11 '24 08:06 JohannesAck
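
As an illustration of the last comment, here is a minimal sketch of an `io_callback` call whose declared result shape matches what the host function actually returns. It is not the code from this thread: `host_row_sums` and `f` are hypothetical names, and the example only shows the matching case, on the assumption that a mismatch between the declared `ShapeDtypeStruct` and the returned array is the kind of mistake described above.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def host_row_sums(x):
    # Runs on the host; returns one float32 value per row of x.
    return np.asarray(x).sum(axis=1).astype(np.float32)


@jax.jit
def f(x):
    # The declared result must match the callback's output exactly:
    # shape (batch,), dtype float32. Declaring x.shape here instead
    # would be a shape mismatch.
    result_spec = jax.ShapeDtypeStruct((x.shape[0],), jnp.float32)
    return io_callback(host_row_sums, result_spec, x)


out = f(jnp.ones((4, 3), dtype=jnp.float32))
print(out.shape)
```

When the batch dimension changes, the declared shape has to change with it, which is why the sketch derives it from `x.shape[0]` rather than hard-coding a value.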