equinox
eqx.filter_hessian gives incorrect result for jax.nn.relu
The following code:
# In Colab, `!` runs a shell command; outside a notebook, run `pip install equinox` in a terminal.
!pip install equinox
import jax
from jax import numpy as jnp
import equinox as eqx
print(f"{eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=}")
print(f"{jax.hessian(jax.nn.relu)(jnp.ones(()))=}")
produces:
eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=Array(1., dtype=float32)
jax.hessian(jax.nn.relu)(jnp.ones(()))=Array(0., dtype=float32)
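The result of jax.hessian (0.) is correct: relu is piecewise linear, so its second derivative is 0 everywhere except at the kink at x = 0, and x = 1 is well away from the kink. The result of eqx.filter_hessian (1.) is wrong.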
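As a sanity check that does not involve Equinox at all, the minimal sketch below compares `jax.hessian(jax.nn.relu)` against a central finite-difference estimate of relu's second derivative at a few points away from 0. The helper `relu_second_derivative_fd` is hypothetical and only meant for illustration; the finite difference is not valid at the kink itself, so x = 0 is deliberately excluded.

```python
import jax
import jax.numpy as jnp


def relu_second_derivative_fd(x: float, h: float = 1e-3) -> float:
    """Hypothetical helper: central finite-difference estimate of relu''(x).

    Only meaningful when |x| > h, i.e. away from the kink at x = 0.
    """
    relu = lambda t: max(t, 0.0)
    return (relu(x + h) - 2.0 * relu(x) + relu(x - h)) / h**2


# Points chosen away from the kink at 0.
points = [-2.0, -1.0, 1.0, 2.5]

# relu is linear on each side of 0, so these estimates are ~0 up to rounding.
fd_values = [relu_second_derivative_fd(x) for x in points]

# jax.hessian differentiates relu's custom derivative rule twice; the result is 0.
jax_values = [float(jax.hessian(jax.nn.relu)(jnp.asarray(x))) for x in points]

print(fd_values)
print(jax_values)
```

Both lists agree that the second derivative is zero at these points, matching the jax.hessian output in the report.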
I tested this in Google Colab with Equinox 0.11.4, JAX 0.4.26, and Python 3.10.12.
This looks fixed on main: there, both jax.hessian and eqx.filter_hessian output 0.0. I believe the cause was an error in jacfwd similar to this one: https://github.com/patrick-kidger/equinox/pull/734.
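To confirm the fix after upgrading (for example with `pip install git+https://github.com/patrick-kidger/equinox.git` to get main), one could compare two Hessian transforms point by point. The sketch below is an assumption-laden illustration, not part of Equinox: `hessians_agree` is a hypothetical helper, and the self-check uses only JAX, comparing `jax.hessian` with an explicit forward-over-reverse composition `jax.jacfwd(jax.jacrev(f))`.

```python
from typing import Callable, Iterable

import jax
import jax.numpy as jnp


def hessians_agree(
    f: Callable,
    hess_a: Callable,
    hess_b: Callable,
    xs: Iterable[float],
    atol: float = 1e-6,
) -> bool:
    """Hypothetical helper: True if hess_a(f) and hess_b(f) agree at every x in xs.

    hess_a and hess_b are Hessian transforms that take a function and return
    a function, e.g. jax.hessian.
    """
    ha, hb = hess_a(f), hess_b(f)
    for x in xs:
        x = jnp.asarray(x)
        if not bool(jnp.allclose(ha(x), hb(x), atol=atol)):
            return False
    return True


# Explicit forward-over-reverse composition, used here as a reference transform.
forward_over_reverse = lambda f: jax.jacfwd(jax.jacrev(f))

ok = hessians_agree(jax.nn.relu, jax.hessian, forward_over_reverse, [-1.0, 1.0, 3.0])
print(ok)
```

With Equinox installed, the analogous check would be `hessians_agree(jax.nn.relu, jax.hessian, eqx.filter_hessian, [1.0])`; based on the outputs reported in this thread, it would be expected to return False on 0.11.4 and True on main.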