
Gradient with Sequential and jax.nn.<activation functions>

Status: open. mfouesneau opened this issue 2 months ago (4 comments).

Hi Patrick,

I am trying to use eqx.nn.Sequential to keep my code clean, but I run into an issue when I try to compute gradients.

Here is a simple example:

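```python
import equinox as eqx
import jax
import jax.numpy as jnp
from jax import value_and_grad  # assumed: the snippet calls value_and_grad unqualified

input_n = 100
key = jax.random.PRNGKey(42)

model = eqx.nn.Sequential([
    eqx.nn.Linear(input_n, 1, key=key),
    eqx.nn.Lambda(eqx.nn.PReLU())
])

@eqx.filter_jit
def loss(model, x):
    y = jax.vmap(model)(x)
    return jnp.sum(y)

x = jax.random.normal(key, (1, input_n))
jax.vmap(model)(x)    # ok
value_and_grad(loss)(model, x)   # ok with PReLU; errors once Lambda wraps e.g. jax.nn.gelu
```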

This version works, but if I replace the activation inside the Lambda with any jax.nn activation function, e.g. jax.nn.relu or jax.nn.gelu, the gradient call fails with:

```
TypeError: Argument '<function gelu at 0x7b51a1fd23b0>' of type <class 'function'> is not a valid JAX type.
```

I looked at your test_nn, which does not actually test taking gradients through Lambda. I also searched the examples and FAQ#typeerror-not-a-valid-jax-type, but I still do not understand the issue; the error message is too cryptic for me.

Any ideas what's going wrong here?

mfouesneau, Apr 16 '24 13:04