equinox
Gradient with `Sequential` and `jax.nn` activation functions
Hi Patrick,
I am trying to use `Sequential` to keep my code clean, and I ran into an issue when computing gradients.
Here is a simple example:
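```python
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import value_and_grad

input_n = 100
key = jax.random.PRNGKey(42)
model = eqx.nn.Sequential([
    eqx.nn.Linear(input_n, 1, key=key),
    eqx.nn.Lambda(eqx.nn.PReLU()),
])

@eqx.filter_jit
def loss(model, x):
    # Apply the model to each row of x and sum all outputs into a scalar.
    y = jax.vmap(model)(x)
    return jnp.sum(y)

x = jax.random.normal(key, (1, input_n))
jax.vmap(model)(x)  # ok
value_and_grad(loss)(model, x)  # error
```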
This version works, but if I replace the activation function with anything else, e.g. `jax.nn.relu` or `jax.nn.gelu`, I get:
```
TypeError: Argument '<function gelu at 0x7b51a1fd23b0>' of type <class 'function'> is not a valid JAX type.
```
I looked at your `test_nn`, which doesn't actually test taking gradients through `Lambda`. I also searched the examples and the FAQ#typeerror-not-a-valid-jax-type entry, but I still don't understand the issue: the error message is too cryptic for me.
Any ideas what's going wrong here?