diffrax

JIT compiling VirtualBrownianTree.evaluate() hangs indefinitely when using jnp.int32 to pass shape to ShapeDtypeStruct

Status: Open. alexander-de-ranitz opened this issue 1 month ago (5 comments).

I ran into this rather mysterious issue, which caused my code to hang indefinitely without responding to Ctrl+C either. It arises in the following (admittedly rather specific) setting:

  • I use jnp.int32 values instead of Python ints when passing a shape to ShapeDtypeStruct
  • VirtualBrownianTree.evaluate() is called inside a jitted (or eqx.filter_jit) function, which is defined and called more than once

The first time the function is called, it compiles and runs perfectly fine. The second time, however, the jitted function starts compiling but never finishes; see the MWE below. This only occurs when the shape is given as jnp.int32 values; with regular Python ints the code works fine. The same issue occurs with UnsafeBrownianPath.

MWE:

import diffrax as dfx
import jax
from jax import numpy as jnp
from jax import random as jr

def main(): 
    def generate_path(shape): 
        shape_dtype = jax.ShapeDtypeStruct(
            shape=shape,
            dtype=jnp.float32,
        )

        @jax.jit
        def eval_path(path, t0, t1):
            print("Compiling path evaluation...")
            return path.evaluate(t0=t0, t1=t1)
        
        path = dfx.VirtualBrownianTree(
            t0=0,
            t1=1,
            tol=1e-5,
            shape=shape_dtype,
            key=jr.PRNGKey(0),
            levy_area=dfx.SpaceTimeLevyArea,
        )
        eval_path(path, 0.0, 1.0)
        print("Path evaluated successfully.")

    generate_path(shape=(jnp.int32(2),)) # Works fine
    generate_path(shape=(jnp.int32(2),)) # Hangs indefinitely

if __name__ == "__main__":
    main()

Output:

Compiling path evaluation...
Path evaluated successfully.
Compiling path evaluation...
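A note on reading this output: a Python print inside a jitted function runs only while JAX traces that function, not on every call. Because eval_path is redefined on each call to generate_path, each call creates a fresh function object that has to be traced (and compiled) again, so the message appears once per definition. The sketch below illustrates this with a trace counter in plain JAX; make_fn and traces are hypothetical names, and it says nothing about why the second trace hangs.

```python
import jax
import jax.numpy as jnp

traces = []  # one entry is appended each time JAX traces a function

def make_fn():
    # Each call returns a *new* jitted function object, like eval_path in the MWE.
    @jax.jit
    def f(x):
        traces.append(1)  # Python side effect: runs at trace time only
        return x * 2.0
    return f

# Reusing one jitted function: traced once, later calls hit the cache.
g = make_fn()
g(jnp.float32(1.0))
g(jnp.float32(3.0))
traced_once = len(traces)

# Redefining the function, as generate_path does: each definition is traced anew.
make_fn()(jnp.float32(1.0))
make_fn()(jnp.float32(1.0))
traced_after_redefinition = len(traces)
```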

Perhaps using jnp.int32 to create a ShapeDtypeStruct is simply wrong here, but that is not obvious to me; if it is wrong, an error or warning would be warranted, I suppose. Since scalars are zero-dimensional arrays in JAX, the ShapeDtypeStructs you get with different types for the shape argument are slightly different:

shape_int32 = jax.ShapeDtypeStruct(shape=(jnp.int32(2),), dtype=jnp.float32)
# shape_int32.shape = (Array(2, dtype=int32),)
# shape_int32.shape[0] = 2
# type(shape_int32.shape[0]) = <class 'jaxlib._jax.ArrayImpl'>

shape_int = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)
# shape_int.shape = (2,)
# shape_int.shape[0] = 2
# type(shape_int.shape[0]) = <class 'int'>

The issue can be easily avoided by just using int instead of jnp.int32, but I am very curious why this would make such a difference.
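A minimal sketch of that workaround: convert every entry of the shape to a plain Python int before building the ShapeDtypeStruct. int() works on concrete integer scalars such as jnp.int32(2), but it raises on traced values inside jit, which is reasonable since shapes must be static anyway. The helper name make_shape_dtype is hypothetical.

```python
import jax
import jax.numpy as jnp

def make_shape_dtype(shape, dtype=jnp.float32):
    """Build a ShapeDtypeStruct whose shape contains only Python ints.

    Concrete scalars such as jnp.int32(2) are converted with int(); a traced
    (non-concrete) dimension makes int() raise instead of silently passing an array.
    """
    static_shape = tuple(int(d) for d in shape)
    return jax.ShapeDtypeStruct(shape=static_shape, dtype=dtype)

# Same input as the MWE, but the resulting shape is (2,) with a Python int.
shape_dtype = make_shape_dtype((jnp.int32(2),))
```

In the MWE, the jax.ShapeDtypeStruct(...) call inside generate_path would then become make_shape_dtype(shape).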

I tested the MWE with both JAX 0.7.0 and 0.6.2, with identical results.

alexander-de-ranitz, Nov 13 '25 08:11

When I run your MWE, I get the following error from JAX:

ValueError: static arguments should be comparable using __eq__.The following error was raised when 
comparing two objects of types <class 'tuple'> and <class 'tuple'>. The error was:
ValueError: Exception raised while checking equality of metadata fields of pytree. Make sure that metadata 
fields are hashable and have simple equality semantics. (Note: arrays cannot be passed as metadata fields!)

So I think this comes down to using an array instead of an int here.
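The error message hints at why: pytree metadata fields must be hashable and have simple __eq__ semantics, and a JAX array has neither property. The check below (plain JAX, no diffrax involved) contrasts a Python int with jnp.int32(2); it illustrates the two properties the message asks for and does not reproduce the exact code path inside JAX or diffrax.

```python
import jax.numpy as jnp

py_dim = 2              # what ShapeDtypeStruct normally receives
jnp_dim = jnp.int32(2)  # what the MWE passes instead (a 0-d jax.Array)

# A Python int hashes fine, and == returns a plain bool.
py_hash = hash(py_dim)
py_eq = py_dim == 2

# A JAX array is unhashable, like a NumPy array ...
try:
    hash(jnp_dim)
    jnp_is_hashable = True
except TypeError:
    jnp_is_hashable = False

# ... and == is elementwise: it returns another (0-d, boolean) array, not a bool.
jnp_eq = jnp_dim == jnp.int32(2)
```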

FWIW I'm using:

jax:    0.8.0
jaxlib: 0.8.0
numpy:  2.3.4
python: 3.13.3 (main, Apr  8 2025, 13:54:08) [Clang 16.0.0 (clang-1600.0.26.6)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node=REDACTED, release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:26 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8112', machine='arm64')

johannahaffner, Nov 13 '25 21:11

It does seem to be fixed by not using a jnp array in the shape structure. However, in a fresh Colab notebook with diffrax pip-installed (JAX 0.7.2), it does just hang, so upgrading JAX at least turns the hang into a clear error:

import diffrax as dfx
import jax
from jax import numpy as jnp
import equinox as eqx
from jax import random as jr

def main(): 
    def generate_path(shape): 
        shape_dtype = jax.ShapeDtypeStruct(
            shape=shape,
            dtype=jnp.float32,
        )

        #@jax.jit
        @eqx.filter_jit
        def eval_path(path, t0, t1):
            print("Compiling path evaluation...")
            return path.evaluate(t0=t0, t1=t1)
        
        path = dfx.VirtualBrownianTree(
            t0=0,
            t1=1,
            tol=1e-2,
            shape=shape_dtype,
            key=jr.PRNGKey(0),
            levy_area=dfx.SpaceTimeLevyArea,
        )
        print(path)
        path = eval_path(path, 0.0, 1.0)
        print("Path evaluated successfully.", path)
        return path

    _ = jax.block_until_ready(generate_path(shape=(jnp.int32(2),))) # Works fine
    _ = jax.block_until_ready(generate_path(shape=(jnp.int32(2),))) # Hangs indefinitely

if __name__ == "__main__":
    main()

lockwo, Nov 13 '25 21:11

Thanks for looking into this, and good to know that this produces an error in JAX 0.8.0. Out of curiosity, do you have any idea how this could cause a hang in older JAX versions?

alexander-de-ranitz, Nov 14 '25 09:11

Out of curiosity, do you have any idea how this could cause a hang in older Jax versions?

I would also like to know this; I have no idea.

johannahaffner, Nov 14 '25 10:11

How weird!

One could potentially use py-spy to attach to the process and observe the stack trace of the stuck program.
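A related option that needs no extra install is Python's standard-library faulthandler module (a swapped-in technique, not something tried in this thread): arm a watchdog that prints the Python stack of every thread if a call is still running after a timeout. The dump comes from a separate watchdog thread, so it should still appear when the main thread is stuck and ignores Ctrl+C. It only shows Python frames, so for a hang inside compiled XLA/C++ code it would point at the Python call that entered it; py-spy can additionally sample native frames with its --native option. run_with_watchdog is a hypothetical helper.

```python
import faulthandler

def run_with_watchdog(fn, timeout_s=120.0):
    """Call fn(); if it has not returned after timeout_s seconds, dump all thread stacks to stderr.

    exit=False keeps the process alive after the dump, so the hang can still be
    observed (or attached to with py-spy) afterwards.
    """
    faulthandler.dump_traceback_later(timeout_s, repeat=False, exit=False)
    try:
        return fn()
    finally:
        # Disarm the watchdog once fn() returns or raises.
        faulthandler.cancel_dump_traceback_later()
```

For the MWE, calling run_with_watchdog(main, timeout_s=60.0) in place of main() would print where the second compilation is stuck.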

patrick-kidger, Nov 15 '25 05:11