
Help understanding this behavior with pure jax jit compatibility

Open • Rumoa opened this issue 2 months ago • 1 comment

I have a question about the compatibility of jax.jit with Equinox models. The quick example in the Equinox repository says that models are fully compatible with normal JAX operations, but in the following example we cannot use jax.jit, only eqx.filter_jit:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.key(0)
m1 = eqx.nn.MLP(4, 4, 64, 2, jax.nn.gelu, key=key)


@jax.jit
def f(model, x):
    return model(x)


# Raises TypeError: the MLP's activation is a Python function, not an array.
f(m1, jnp.arange(4))
```

We get the following error:

```
TypeError: Error interpreting argument to <function f at 0x7da688a4b2e0> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path model.activation.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
```

I understand this is caused by the MLP's activation function, and that eqx.filter_jit avoids the problem, but I was wondering whether there is a workaround that lets me use jax.jit in cases like this.

Thank you for developing this library. Alongside Diffrax, it's been very helpful in my research projects.

Rumoa • Oct 21 '25 11:10

Thank you, I'm glad you like Equinox and Diffrax!

So the easiest way to do this is to use eqx.{partition, combine} to split your model into the static and non-static parts. See Option 1 here.

I hope that helps :)

patrick-kidger • Oct 21 '25 13:10