Add support for `jacrev`, `jacfwd`, `hessian`, `vmap`

Open martinjm97 opened this issue 2 years ago • 3 comments

Extending the tests in https://github.com/EnzymeAD/Enzyme-JAX/blob/d9e2ae086aaac98a0e052a6d8e6d2ed163c00e6a/test/test.py#L68 with each of the calls below (marked `>`) raises the error shown on the line after it:

> x = jax.jacrev(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Batching rule for 'enzyme_rev' not implemented
> x = jax.jacfwd(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Batching rule for 'enzyme_fwd' not implemented
> x = jax.hessian(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.]))
NotImplementedError: Differentiation rule for 'enzyme_aug' not implemented
> x = jax.jit(jax.vmap(lambda x: add_one(x, jnp.array([1., 2., 3.]))))(jnp.array([jnp.array([1., 2., 3.])]*5))
NotImplementedError: Batching rule for 'enzyme_primal' not implemented
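
For reference, here is a rough sketch of what the same four calls return once the transformations are supported. It uses a hypothetical pure-JAX stand-in, `add_one_ref`, which is an assumption for illustration only: it is not the Enzyme-backed `add_one` from test.py, and its arithmetic (`x + y + 1`) may not match what that function computes. Only the call patterns and result shapes are meant to carry over.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in written in plain JAX. This is NOT the Enzyme-backed
# add_one from test/test.py; the formula is assumed for illustration only.
def add_one_ref(x, y):
    return x + y + 1.0


v = jnp.array([1., 2., 3.])

# Reverse- and forward-mode Jacobians with respect to the first argument.
jac_rev = jax.jacrev(add_one_ref)(v, v)
jac_fwd = jax.jacfwd(add_one_ref)(v, v)

# Second derivative with respect to the first argument; the stand-in is
# linear in x, so every entry is zero.
hess = jax.hessian(add_one_ref)(v, v)

# Map over a leading batch axis of size 5, as in the failing vmap call.
batch = jnp.stack([v] * 5)
batched = jax.jit(jax.vmap(lambda x: add_one_ref(x, v)))(batch)

print(jac_rev.shape, jac_fwd.shape, hess.shape, batched.shape)
```

A comparison against a plain-JAX reference like this could serve as the test once batching and higher-order differentiation rules exist for the `enzyme_*` primitives.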

martinjm97 • Jul 30 '23 18:07

Are jax.grad and jax.value_and_grad already supported?

croci • Nov 03 '23 22:11

Hi @croci,

I believe so. I just tested it by adding:

x = jax.grad(add_one)(.1, .1)
x = jax.value_and_grad(add_one)(1., 2.)

to the end of the test.py file.

--Jesse

martinjm97 • Nov 05 '23 03:11

@martinjm97 ok great! I couldn't get it to work on the pip version, but that is v0.0.4 (and I confess I didn't try too hard). Thanks a lot!

croci • Nov 05 '23 03:11