equinox
Jax 0.4.27: ValueError: safe_map() argument 2 is shorter than argument 1
Since JAX 0.4.27, several equinox tests fail with the following error:
args = (_ClosureConvert(
jaxpr={ lambda ; a:f32[47] b:f32[] c:i32[] d:bool[] e:bool[] f:i32[] g:f32[] h:f32[47]
i:bool[...t 0x7fff4c435a90>,
_makes_false_steps=False
), Traced<ShapedArray(float32[47])>with<DynamicJaxprTrace(level=5/1)>))))
kwds = {}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E ValueError: safe_map() argument 2 is shorter than argument 1
E --------------------
E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
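The message comes from a strict, length-checked map: it applies a function elementwise over several sequences and refuses to silently truncate when they differ in length, much like Python 3.10's `zip(..., strict=True)`. The sketch below is a hypothetical pure-Python stand-in (`strict_map` is an illustrative name, not JAX's implementation) that reproduces the shape of the error. It shows how binding a list of expected inputs to a shorter list of values triggers "argument 2 is shorter than argument 1". It is meant to illustrate the message, not to diagnose the equinox failure.

```python
# Hypothetical stand-in for a length-checked map such as JAX's internal
# `safe_map`. This is NOT JAX's code; it only mimics the error wording.
def strict_map(f, *args):
    """Apply f elementwise across args, raising if the lengths differ."""
    args = [list(a) for a in args]
    n = len(args[0])
    # Compare every later argument against the first one (1-based positions,
    # matching the wording of the error message).
    for i, a in enumerate(args[1:], start=2):
        if len(a) < n:
            raise ValueError(f"strict_map() argument {i} is shorter than argument 1")
        if len(a) > n:
            raise ValueError(f"strict_map() argument {i} is longer than argument 1")
    return [f(*xs) for xs in zip(*args)]


# Illustration: three expected inputs (hypothetical variable names) but only
# two supplied values, so the second argument is shorter than the first.
expected_inputs = ["a", "b", "c"]
supplied_values = [1.0, 2.0]
env = {}
try:
    strict_map(env.__setitem__, expected_inputs, supplied_values)
except ValueError as err:
    print(err)
```

Running it prints `strict_map() argument 2 is shorter than argument 1`, the same pattern as the traceback above. Checking lengths up front, rather than letting `zip` stop at the shortest input, is what turns a silent mismatch into an immediate error.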