Enzyme-JAX
Add support for `jacrev`, `jacfwd`, `hessian`, `vmap`
Extending the tests in https://github.com/EnzymeAD/Enzyme-JAX/blob/d9e2ae086aaac98a0e052a6d8e6d2ed163c00e6a/test/test.py#L68 with the following calls produces these errors:
```python
x = jax.jacrev(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
# NotImplementedError: Batching rule for 'enzyme_rev' not implemented

x = jax.jacfwd(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
# NotImplementedError: Batching rule for 'enzyme_fwd' not implemented

x = jax.hessian(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
# NotImplementedError: Differentiation rule for 'enzyme_aug' not implemented
```
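For context on why these fail differently: in stock JAX, `jax.jacrev` builds the Jacobian by `vmap`-ing the VJP over the standard basis, `jax.jacfwd` does the same with the JVP, and `jax.hessian` is `jacfwd(jacrev(f))`. So the first two need a batching rule for the primitive, and the Hessian additionally needs a forward-mode (JVP) rule for whatever runs inside the reverse pass, presumably `enzyme_aug` here. The sketch below spells this out on stock JAX with a hypothetical pure-JAX stand-in for `add_one` (`jnp.sin(x) * y + 1.0`, not the test's real Enzyme-backed function); `jacrev_by_hand` and `jacfwd_by_hand` are illustrative names, not library functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure-JAX stand-in for the test's add_one. The real function
# calls into Enzyme and lowers to enzyme_* primitives; this one does not.
def add_one(x, y):
    return jnp.sin(x) * y + 1.0


def jacrev_by_hand(f, x, y):
    # One reverse pass (VJP) per output row, batched with vmap over the
    # standard basis. This vmap is why jacrev needs a batching rule.
    out, vjp_fn = jax.vjp(lambda x_: f(x_, y), x)
    basis = jnp.eye(out.size, dtype=out.dtype)
    (rows,) = jax.vmap(vjp_fn)(basis)  # vjp_fn returns one cotangent per primal
    return rows


def jacfwd_by_hand(f, x, y):
    # One forward pass (JVP) per input column, again batched with vmap.
    g = lambda x_: f(x_, y)
    basis = jnp.eye(x.size, dtype=x.dtype)
    cols = jax.vmap(lambda v: jax.jvp(g, (x,), (v,))[1])(basis)
    return cols.T  # row i of cols is J @ e_i, i.e. column i of J


x = jnp.array([1., 2., 3.])
y = jnp.array([1., 2., 3.])

J_rev = jacrev_by_hand(add_one, x, y)
J_fwd = jacfwd_by_hand(add_one, x, y)

# jax.hessian is jacfwd applied to jacrev: forward mode through the reverse
# pass, so every primitive in the reverse pass also needs a JVP rule.
H = jax.hessian(add_one)(x, y)
H_composed = jax.jacfwd(jax.jacrev(add_one))(x, y)
```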
```python
x = jax.jit(jax.vmap(lambda x: add_one(x, jnp.array([1., 2., 3.]))))(jnp.array([jnp.array([1., 2., 3.])]*5))
# NotImplementedError: Batching rule for 'enzyme_primal' not implemented
```
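Until `enzyme_primal` has a batching rule, one possible workaround (an assumption, not tested against Enzyme-JAX) is `jax.lax.map`, which applies the function once per row inside a `scan` instead of batching it, so the primitive is only ever traced unbatched. The sketch again uses a hypothetical pure-JAX `add_one` stand-in, so it only demonstrates on stock JAX that both routes give the same result.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure-JAX stand-in for the test's add_one.
def add_one(x, y):
    return jnp.sin(x) * y + 1.0

y = jnp.array([1., 2., 3.])
# Same (5, 3) input as jnp.array([jnp.array([1., 2., 3.])]*5).
xs = jnp.tile(jnp.array([1., 2., 3.]), (5, 1))

f = lambda x: add_one(x, y)

# vmap rewrites every primitive in f into its batched form, so each one
# needs a batching rule (the step that fails for enzyme_primal).
batched = jax.jit(jax.vmap(f))(xs)

# lax.map runs f on one row at a time inside a scan, so f is traced
# unbatched and no batching rule is required.
looped = jax.jit(lambda xs_: jax.lax.map(f, xs_))(xs)
```

The trade-off is speed: `lax.map` processes the rows sequentially, whereas `vmap` produces a single batched computation.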
Are `jax.grad` and `jax.value_and_grad` already supported?
Hi @croci,
I believe so. I just tested it by adding:
```python
x = jax.grad(add_one)(.1, .1)
x = jax.value_and_grad(add_one)(1., 2.)
```
to the end of the `test.py` file.
--Jesse
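For reference, a small sketch of what those two calls compute, using a hypothetical scalar stand-in for `add_one` (not the test's Enzyme-backed function): `jax.grad` differentiates with respect to the first argument only unless `argnums` is given, and `jax.value_and_grad` returns the pair `(value, gradient)`.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure-JAX stand-in for the test's add_one.
def add_one(x, y):
    return jnp.sin(x) * y + 1.0

# Gradient with respect to x only (argnums defaults to 0): cos(x) * y.
g = jax.grad(add_one)(0.1, 0.1)

# Value and gradient together, again with respect to x only.
v, gv = jax.value_and_grad(add_one)(1., 2.)

# Gradients with respect to both arguments.
gx, gy = jax.grad(add_one, argnums=(0, 1))(1., 2.)
```

Inputs need to be floating point, as with `.1` and `1.` in the calls above; `jax.grad` raises a `TypeError` for integer inputs.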
@martinjm97 Ok, great! I couldn't get it to work with the pip version, but that is v0.0.4 (and I confess I didn't try too hard). Thanks a lot!