JIT compiling VirtualBrownianTree.evaluate() hangs indefinitely when using jnp.int32 to pass shape to ShapeDtypeStruct
I ran into this rather mysterious issue, which caused my code to hang indefinitely without responding to Ctrl+C either. The issue arises in the following (admittedly rather specific) setting:
- I use jnp.int32s instead of Python's ints when passing a shape to ShapeDtypeStruct
- VirtualBrownianTree.evaluate() is inside a jit'ted (or eqx.filter_jit) function, which is defined and called more than once
The first time the function is called, it compiles and runs perfectly fine. The second time, however, the JIT'ted function begins compiling but never finishes (see the MWE below). Note that this only occurs when using jnp.int32s to pass a shape; with regular ints this code works fine. The same issue occurs with UnsafeBrownianPath.
MWE:
import diffrax as dfx
import jax
from jax import numpy as jnp
from jax import random as jr


def main():
    def generate_path(shape):
        shape_dtype = jax.ShapeDtypeStruct(
            shape=shape,
            dtype=jnp.float32,
        )

        @jax.jit
        def eval_path(path, t0, t1):
            print("Compiling path evaluation...")
            return path.evaluate(t0=t0, t1=t1)

        path = dfx.VirtualBrownianTree(
            t0=0,
            t1=1,
            tol=1e-5,
            shape=shape_dtype,
            key=jr.PRNGKey(0),
            levy_area=dfx.SpaceTimeLevyArea,
        )
        eval_path(path, 0.0, 1.0)
        print("Path evaluated successfully.")

    generate_path(shape=(jnp.int32(2),))  # Works fine
    generate_path(shape=(jnp.int32(2),))  # Hangs indefinitely


if __name__ == "__main__":
    main()
Output:
Compiling path evaluation...
Path evaluated successfully.
Compiling path evaluation...
Perhaps using jnp.int32 to create a ShapeDtypeStruct is simply wrong here, but this is not obvious to me; at least an error or warning would be warranted, I suppose. Since scalars are zero-dimensional arrays in JAX, the ShapeDtypeStructs you get with different types for the shape argument are slightly different:
shape_int32 = jax.ShapeDtypeStruct(shape=(jnp.int32(2),), dtype=jnp.float32)
# shape_int32.shape = (Array(2, dtype=int32),)
# shape_int32.shape[0] = 2
# type(shape_int32.shape[0]) = <class 'jaxlib._jax.ArrayImpl'>
shape_int = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)
# shape_int.shape = (2,)
# shape_int.shape[0] = 2
# type(shape_int.shape[0]) = <class 'int'>
The issue can be easily avoided by just using int instead of jnp.int32, but I am very curious why this would make such a difference.
I tested the MWE in both JAX 0.7.0 and 0.6.2, with identical results.
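A minimal sketch of the workaround mentioned above, i.e. converting every dimension to a plain Python int before building the ShapeDtypeStruct. The helper name `to_static_shape` is hypothetical (it is not part of JAX or diffrax), and the sketch assumes the dimensions are concrete integer scalars that implement `__index__`, which concrete JAX integer scalars such as `jnp.int32(2)` should do:

```python
import operator


def to_static_shape(shape):
    """Return `shape` as a tuple of plain Python ints.

    Hypothetical helper: operator.index() accepts any object implementing
    __index__ (Python ints, NumPy integer scalars, and, by assumption,
    concrete JAX integer scalars) and rejects floats with a TypeError
    instead of silently truncating them.
    """
    return tuple(operator.index(dim) for dim in shape)


# Intended usage with the MWE (requires jax, so left as a comment here):
#   shape_dtype = jax.ShapeDtypeStruct(
#       shape=to_static_shape((jnp.int32(2),)),
#       dtype=jnp.float32,
#   )
```

With this, the shape stored in the ShapeDtypeStruct is `(2,)` with `int` entries, matching the `shape_int` case above.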
When I'm running your MWE I get the following error, coming from JAX:
ValueError: static arguments should be comparable using __eq__. The following error was raised when
comparing two objects of types <class 'tuple'> and <class 'tuple'>. The error was:
ValueError: Exception raised while checking equality of metadata fields of pytree. Make sure that metadata
fields are hashable and have simple equality semantics. (Note: arrays cannot be passed as metadata fields!)
So I think this comes down to using an array instead of an int here.
FWIW I'm using:
jax: 0.8.0
jaxlib: 0.8.0
numpy: 2.3.4
python: 3.13.3 (main, Apr 8 2025, 13:54:08) [Clang 16.0.0 (clang-1600.0.26.6)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node=REDACTED, release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:26 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8112', machine='arm64')
It does seem to be fixed by not using the jnp array in the shape structure. With a fresh Colab notebook with diffrax pip-installed (JAX 0.7.2) it does just hang, so upgrading JAX at least turns the hang into a clear error.
import diffrax as dfx
import jax
from jax import numpy as jnp
import equinox as eqx
from jax import random as jr


def main():
    def generate_path(shape):
        shape_dtype = jax.ShapeDtypeStruct(
            shape=shape,
            dtype=jnp.float32,
        )

        # @jax.jit
        @eqx.filter_jit
        def eval_path(path, t0, t1):
            print("Compiling path evaluation...")
            return path.evaluate(t0=t0, t1=t1)

        path = dfx.VirtualBrownianTree(
            t0=0,
            t1=1,
            tol=1e-2,
            shape=shape_dtype,
            key=jr.PRNGKey(0),
            levy_area=dfx.SpaceTimeLevyArea,
        )
        print(path)
        path = eval_path(path, 0.0, 1.0)
        print("Path evaluated successfully.", path)
        return path

    _ = jax.block_until_ready(generate_path(shape=(jnp.int32(2),)))  # Works fine
    _ = jax.block_until_ready(generate_path(shape=(jnp.int32(2),)))  # Hangs indefinitely


if __name__ == "__main__":
    main()
Thanks for looking into this, and good to know that this produces an error in JAX 0.8.0. Out of curiosity, do you have any idea how this could cause a hang in older JAX versions?
> Out of curiosity, do you have any idea how this could cause a hang in older JAX versions?
I would also like to know this, no idea.
How weird!
One could potentially use py-spy to attach to the process and observe the stack trace of the stuck program.
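For py-spy, something like `py-spy dump --pid <PID>` against the stuck process should print its current stack. As an in-process alternative (a different technique, not py-spy), here is a rough sketch using the standard library's `faulthandler` module. The wrapper name `run_with_watchdog`, the timeout and the log path are all made up for illustration. It only shows Python-level frames, so if the hang sits inside native compilation code it may only reveal which Python call entered it:

```python
import faulthandler


def run_with_watchdog(fn, timeout_s=60.0, log_path="hang_traceback.log"):
    """Call fn() and return its result.

    If fn() is still running after timeout_s seconds, the Python stack of
    every thread is written to log_path, and again every timeout_s seconds
    until fn() returns. Hypothetical helper for diagnosing hangs.
    """
    with open(log_path, "w") as log_file:
        # faulthandler implements this timer with a watchdog thread and
        # writes straight to the file descriptor of log_file.
        faulthandler.dump_traceback_later(timeout_s, repeat=True, file=log_file)
        try:
            return fn()
        finally:
            # Disarm the watchdog once fn() has finished (or raised).
            faulthandler.cancel_dump_traceback_later()


# Intended usage with the MWE above (main is the MWE's entry point):
#   run_with_watchdog(main, timeout_s=30.0)
```

If the second `generate_path` call hangs, the log file should then contain the Python frames that were active at each timeout, even if the process later has to be killed.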