
JAX 0.4.27: ValueError: safe_map() argument 2 is shorter than argument 1

Open GaetanLepage opened this issue 2 months ago • 7 comments

Since JAX 0.4.27, several tests fail with:

args = (_ClosureConvert(
  jaxpr={ lambda ; a:f32[47] b:f32[] c:i32[] d:bool[] e:bool[] f:i32[] g:f32[] h:f32[47]
    i:bool[...t 0x7fff4c435a90>,
  _makes_false_steps=False
), Traced<ShapedArray(float32[47])>with<DynamicJaxprTrace(level=5/1)>))))
kwds = {}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           ValueError: safe_map() argument 2 is shorter than argument 1
E           --------------------
E           For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
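For readers unfamiliar with the message: it reads like a length check on the sequences passed to a map, where the second sequence has fewer elements than the first. The following is only a minimal sketch of that kind of check in plain Python. `checked_map` is a hypothetical stand-in written for illustration, not JAX's actual `safe_map` implementation, and it says nothing about which call inside JAX or equinox triggers the failure here.

```python
def checked_map(f, *seqs):
    """Map f over several sequences, refusing to run if their lengths differ.

    Hypothetical illustration only; this is not JAX's internal safe_map.
    """
    seqs = [list(s) for s in seqs]
    n = len(seqs[0]) if seqs else 0
    # Compare every later argument against the first one (1-based numbering,
    # matching the style of the error message in the traceback above).
    for i, s in enumerate(seqs[1:], start=2):
        if len(s) < n:
            raise ValueError(
                f"checked_map() argument {i} is shorter than argument 1"
            )
        if len(s) > n:
            raise ValueError(
                f"checked_map() argument {i} is longer than argument 1"
            )
    return [f(*items) for items in zip(*seqs)]


# Equal lengths: behaves like an ordinary map.
print(checked_map(lambda a, b: a + b, [1, 2], [10, 20]))

# Mismatched lengths: raises instead of silently truncating.
try:
    checked_map(lambda a, b: a + b, [1, 2, 3], [10, 20])
except ValueError as err:
    print(err)
```

To see where the mismatch actually arises, the full internal frames can be included by following the hint in the traceback, for example by running the tests with the environment variable `JAX_TRACEBACK_FILTERING=off` set.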

GaetanLepage · May 07 '24 21:05