Iurii Kemaev

2 issues by Iurii Kemaev

### Description

`jax.closure_convert` applied to a vjp closure triggers an empty `AssertionError`, as if there were a mismatch in the input trees. Code for reproduction:

```
import jax
import jax.numpy...
```

better_errors

### Description

The following code

```
import jax
import jax.ad_checkpoint
from jax import numpy as jnp

@jax.jit
def apply(params, x):
  def step(y, i):
    y = jnp.sin(y)
    y = jax.ad_checkpoint.checkpoint_name(y, 'save_remat')...
```

bug