
`closure_convert` doesn't work on VJPs

Open • hbq1 opened this issue 5 months ago • 1 comment

Description

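Applying `jax.closure_convert` to the pullback closure returned by `jax.vjp` raises an `AssertionError` with no message. The failing assertion compares input pytree structures, so the error looks like a mismatch in the input trees. Code to reproduce: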

import jax
import jax.numpy as jnp

jax.config.update('jax_traceback_filtering', 'off')

x = jnp.array(2.0)
z = jnp.array(3.0)

@jax.jit
def call_pure_vjp(x, z):

  def f(x, z):
    return z * x

  y, vjp = jax.vjp(f, x, z)
  vjp(x)  # runs OK

  # note: vjp_aux_args is _not_ empty under `@jax.jit`
  vjp_pure, *vjp_aux_args = jax.closure_convert(vjp, x)
  g = vjp_pure(x, *vjp_aux_args)

  return y, g

call_pure_vjp(x, z)

Output (traceback excerpt):

google3/third_party/py/jax/_src/custom_derivatives.py in converted_fun(*args_hconsts)
   1177     consts = merge(closure_consts, hoisted_consts)
   1178     all_args, in_tree2 = tree_flatten(tuple(args))
-> 1179     assert in_tree == in_tree2
   1180     out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
   1181     return tree_unflatten(out_tree, out_flat)
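The failing line compares two pytree structures. Judging from the surrounding source of `converted_fun`, `in_tree` is recorded from the example arguments passed to `jax.closure_convert`, and `in_tree2` is computed from the leading positional arguments of the call once the trailing hoisted constants are split off. The sketch below uses only `jax.tree_util` to show how an extra argument, or a missing one, makes such structures unequal. The variable names are illustrative and are not taken from JAX.

```python
import jax.numpy as jnp
from jax.tree_util import tree_structure

x = jnp.array(2.0)

# Structure recorded from example args (x,): a 1-tuple holding one leaf.
expected = tree_structure((x,))

# An extra list argument, even an empty one, adds a node to the structure.
with_extra_list = tree_structure((x, []))

# Passing no leading arguments at all gives an empty tuple structure.
no_args = tree_structure(())

assert expected != with_extra_list
assert expected != no_args
```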

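For comparison, the `jax.closure_convert` docstring describes the return value as a pair. The first element is the converted callable. The second is a list of the values hoisted out of the closure, and these are passed back as trailing positional arguments. The minimal sketch below applies that convention to the same VJP closure under `@jax.jit`. It assumes the pair is unpacked into exactly two names. `pullback_via_closure_convert` is a hypothetical helper name, and it has not been checked against JAX versions other than the one reported here.

```python
import jax
import jax.numpy as jnp


def f(x, z):
    return z * x


@jax.jit
def pullback_via_closure_convert(x, z):
    # Hypothetical helper mirroring call_pure_vjp from the repro.
    y, vjp = jax.vjp(f, x, z)
    # closure_convert returns a 2-tuple (converted_fn, hoisted_consts), where
    # hoisted_consts is a list; unpack it into exactly two names.
    vjp_pure, hoisted_consts = jax.closure_convert(vjp, x)
    # The hoisted values go back in as trailing positional arguments.
    g = vjp_pure(x, *hoisted_consts)
    return y, g


y, (gx, gz) = pullback_via_closure_convert(jnp.array(2.0), jnp.array(3.0))
# y == 6.0; cotangent 2.0 gives gx == 2.0 * z == 6.0 and gz == 2.0 * x == 4.0
```

Here the list of hoisted values is spread with `*` only at the call site. Star-unpacking the returned pair on the left-hand side works differently. In `vjp_pure, *vjp_aux_args = ...`, `vjp_aux_args` is bound to a one-element list whose single element is that list. The later call then receives the list itself as one extra positional argument, which may be worth ruling out as a contributor to the tree mismatch.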
System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (f142f8afe21bceb00fb495468aa0b5043e98c419)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.10.0-smp-1103.32.0.0', version='#1 [v5.10.0-1103.32.0.0] SMP @1721941885', machine='x86_64')

hbq1, Sep 12 '24 09:09