eqx.filter_hessian gives incorrect result for jax.nn.relu

Open · SimonKoop opened this issue 1 year ago · 1 comment

The following code

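# Colab/IPython shell command; run it before importing equinox.
!pip install equinox

import jax
from jax import numpy as jnp
import equinox as eqx

# Evaluate the Hessian of ReLU at x = 1 with both libraries.
print(f"{eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=}")
print(f"{jax.hessian(jax.nn.relu)(jnp.ones(()))=}")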

results in

eqx.filter_hessian(jax.nn.relu)(jnp.ones(()))=Array(1., dtype=float32)
jax.hessian(jax.nn.relu)(jnp.ones(()))=Array(0., dtype=float32)

Here, the result of jax.hessian is correct but that of eqx.filter_hessian is not: ReLU is linear for positive inputs, so its second derivative at 1 is 0.
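As a cross-check that does not depend on Equinox, here is a minimal sketch that computes the second derivative of ReLU at 1 in two independent ways using plain JAX. The variable names are my own, and this only checks the expected value; it does not reproduce or diagnose the eqx.filter_hessian behaviour.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(())  # scalar input, same as in the report

# For x > 0, relu(x) == x, so the first derivative at x = 1 is 1.
first = jax.grad(jax.nn.relu)(x)

# Second derivative via nested reverse-mode differentiation.
second_nested = jax.grad(jax.grad(jax.nn.relu))(x)

# Second derivative via forward-over-reverse composition.
second_fwd_rev = jax.jacfwd(jax.jacrev(jax.nn.relu))(x)

print(f"{first=}")
print(f"{second_nested=}")
print(f"{second_fwd_rev=}")
```

Both second-derivative values should agree with the jax.hessian output shown above.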

I tested this code in Google Colab with Equinox 0.11.4, JAX 0.4.26, and Python 3.10.12.

SimonKoop, Aug 16 '24 13:08

Looks like this is fixed on main: I see 0.0 as the output for both jax and equinox there. I believe the cause was an error in the jacfwd similar to this one: https://github.com/patrick-kidger/equinox/pull/734.

lockwo, Aug 16 '24 16:08