`closure_convert` doesn't work on VJPs
Description
`jax.closure_convert` applied to a VJP closure triggers an empty `AssertionError`, as if there were a mismatch in the input trees. Code for reproduction:
import jax
import jax.numpy as jnp

jax.config.update('jax_traceback_filtering', 'off')

x = jnp.array(2.0)
z = jnp.array(3.0)

@jax.jit
def call_pure_vjp(x, z):
    def f(x, z):
        return z * x

    y, vjp = jax.vjp(f, x, z)
    vjp(x)  # runs OK
    # note: vjp_aux_args is _not_ empty under `@jax.jit`
    vjp_pure, *vjp_aux_args = jax.closure_convert(vjp, x)
    g = vjp_pure(x, *vjp_aux_args)
    return y, g

call_pure_vjp(x, z)
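For comparison: `jax.closure_convert` is documented as returning a pair `(converted_fun, consts)`, where `consts` is a list. With the star-unpacking above, `vjp_aux_args` becomes a one-element list holding that whole list, so `vjp_pure` receives the constants as a single nested argument rather than as trailing positional arguments. The sketch below, which assumes that this is what trips the check, unpacks the pair directly and splats `consts` into the call. The name `call_pure_vjp_unpacked` is only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def call_pure_vjp_unpacked(x, z):
    def f(x, z):
        return z * x

    y, vjp = jax.vjp(f, x, z)
    # Unpack the (converted_fun, consts) pair without a star, so `consts`
    # stays a flat list of hoisted constants.
    vjp_pure, consts = jax.closure_convert(vjp, x)
    # Pass the constants as trailing positional arguments.
    g = vjp_pure(x, *consts)
    return y, g

y, g = call_pure_vjp_unpacked(jnp.array(2.0), jnp.array(3.0))
# With cotangent x = 2.0: d(z*x)/dx * 2.0 = 6.0 and d(z*x)/dz * 2.0 = 4.0
```

If this variant runs while the original does not, the assertion is likely coming from how the return value is unpacked rather than from `closure_convert` itself.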
Output:
google3/third_party/py/jax/_src/custom_derivatives.py in converted_fun(*args_hconsts)
1177 consts = merge(closure_consts, hoisted_consts)
1178 all_args, in_tree2 = tree_flatten(tuple(args))
-> 1179 assert in_tree == in_tree2
1180 out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
1181 return tree_unflatten(out_tree, out_flat)
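The failing line compares the tree structure of the example arguments (`in_tree`) with that of the positional arguments actually received (`in_tree2`). Assuming `converted_fun` first splits off the trailing `num_consts` arguments as hoisted constants (the step just before line 1177, not shown in the traceback), the hypothetical helper below mirrors that split-and-compare logic with `jax.tree_util.tree_flatten`. It is not a JAX API, and the two constants are arbitrary stand-ins for the hoisted residuals.

```python
import jax
import jax.numpy as jnp

def mimic_converted_fun_check(example_args, args_hconsts, num_consts):
    """Hypothetical helper: mirrors the split-and-compare step in
    converted_fun. Returns True when the input trees match."""
    _, in_tree = jax.tree_util.tree_flatten(tuple(example_args))
    num_args = len(args_hconsts) - num_consts
    args = args_hconsts[:num_args]  # everything before the trailing consts
    _, in_tree2 = jax.tree_util.tree_flatten(tuple(args))
    return in_tree == in_tree2

x = jnp.array(2.0)
consts = [jnp.array(3.0), jnp.array(2.0)]  # stand-ins for hoisted residuals

# Constants splatted: (x, c0, c1) leaves args == (x,), so the trees match.
ok = mimic_converted_fun_check((x,), (x, *consts), num_consts=len(consts))
# Constants passed as one list: (x, [c0, c1]) leaves args == (), a mismatch.
bad = mimic_converted_fun_check((x,), (x, consts), num_consts=len(consts))
```

Under these assumptions, passing the constants as a single nested list shifts the split point, which would produce exactly the bare `assert in_tree == in_tree2` failure shown above.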
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (f142f8afe21bceb00fb495468aa0b5043e98c419)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.10.0-smp-1103.32.0.0', version='#1 [v5.10.0-1103.32.0.0] SMP @1721941885', machine='x86_64')