filter_spec must be static

Open HGangloff opened this issue 7 months ago • 2 comments

Hi,

I have some jitted code that calls eqx.partition, and the code crashes if filter_spec is not declared as static. I am wondering whether this is on purpose, since the docs do not seem to warn the user about it. If it is, I believe the error message could be improved.

MWE:

import jax
import jax.numpy as jnp
import equinox as eqx

@jax.jit
def partition(params, params_mask):
    return eqx.partition(params, filter_spec=params_mask)

params = ((1., 2.), 3.)
params_mask = ((True, False), False)
partition(params, params_mask)

This crashes with ValueError: filter_spec must consist of booleans and callables only.

On the other hand, the two examples below work well.

from functools import partial

@eqx.filter_jit
def partition(params, params_mask):
    return eqx.partition(params, filter_spec=params_mask)

@partial(jax.jit, static_argnames=["params_mask"])
def partition(params, params_mask):
    return eqx.partition(params, filter_spec=params_mask)
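One way to see why the first version fails while the other two work: under plain jax.jit every leaf of a non-static argument is traced, including Python bools, so the mask leaves reach eqx.partition as tracers rather than as booleans. Marking the argument as static (or using eqx.filter_jit, which treats non-array leaves as static) keeps them as plain bools. The sketch below uses only JAX; the helper leaf_is_tracer and the functions traced and static are hypothetical names introduced for illustration.

```python
from functools import partial

import jax


def leaf_is_tracer(mask):
    # For each leaf of the mask, report whether it is a JAX tracer.
    leaves = jax.tree_util.tree_leaves(mask)
    return [isinstance(leaf, jax.core.Tracer) for leaf in leaves]


traced_kinds = []  # filled in while `traced` is being traced
static_kinds = []  # filled in while `static` is being traced


@jax.jit
def traced(params_mask):
    # `params_mask` is a regular (traced) argument here.
    traced_kinds.extend(leaf_is_tracer(params_mask))
    return 0


@partial(jax.jit, static_argnames=["params_mask"])
def static(params_mask):
    # `params_mask` is static here, so its leaves stay Python bools.
    static_kinds.extend(leaf_is_tracer(params_mask))
    return 0


params_mask = ((True, False), False)
traced(params_mask)
static(params_mask)
print(traced_kinds)
print(static_kinds)
```

Note that a static argument must be hashable, which a nested tuple of bools is; a mask built from lists or arrays would need a different approach.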

HGangloff, Jun 13 '25 13:06

Ah, interesting! I've not seen filter_spec passed across a JIT boundary like this before. Usually this is something created before a JIT boundary, so that params can be split into the dynamic and static pieces:

@partial(jax.jit, static_argnames=["static"])
def run(dynamic, static):
    ...

dynamic, static = eqx.partition(params, filter_spec=params_mask)
run(dynamic, static)

(This is pretty much what eqx.filter_jit does for you automatically under the hood.)

On improving the error message, I agree that sounds reasonable! Probably we can detect when it is a traced value, and issue a slightly different message then? If so then I'd be happy to take a PR on this.

patrick-kidger, Jun 14 '25 21:06

I know this is not the most common usage of eqx.partition, but I want to call it inside a jitted loop because I need an alternate solver in a case where I cannot efficiently use the alternatives provided by optax (essentially through their multi_transform). It looks like it could work if my parameter masks are static!

I inserted the corresponding check in _make_filter_tree to raise a more informative error:

if isinstance(mask, jax.core.Tracer):
    raise ValueError(
        "`filter_spec` leaf values cannot be traced arrays."
    )
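To illustrate how such a check behaves when the mask crosses a jax.jit boundary, here is a standalone sketch that does not touch Equinox internals. The helper check_mask_leaf, the wrapper traced_check, and the extra hint appended to the error message are all hypothetical; this is not the code merged into the library.

```python
import jax


def check_mask_leaf(mask):
    # Hypothetical stand-in for the per-leaf validation of `filter_spec`.
    if isinstance(mask, jax.core.Tracer):
        raise ValueError(
            "`filter_spec` leaf values cannot be traced arrays. "
            "Mark the argument as static or use `eqx.filter_jit`."
        )
    if not (isinstance(mask, bool) or callable(mask)):
        raise ValueError("filter_spec must consist of booleans and callables only.")
    return mask


@jax.jit
def traced_check(params_mask):
    # Under plain jax.jit the bool leaves arrive as tracers, so the check fires.
    jax.tree_util.tree_map(check_mask_leaf, params_mask)
    return 0


try:
    traced_check(((True, False), False))
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

Outside of JIT, the same helper accepts plain bools and callables unchanged.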

But I cannot make a working unit test for it, since the tests run with runtime type checkers and jaxtyping catches the tracer value before my check does 😄:

>       filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
E       jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of equinox._filters._make_filter_tree.<locals>._filter_tree.
E       The problem arose whilst typechecking parameter 'mask'.
E       Actual value: bool[]
E       Expected type: bool | collections.abc.Callable[[Any], bool].

HGangloff, Jun 16 '25 07:06

Closing as resolved in #1038!

patrick-kidger, Jul 07 '25 20:07