filter_spec must be static
Hi,
I have some jitted code that calls eqx.partition, and the code crashes unless filter_spec is declared as static. Is this intentional? The docs do not seem to warn the user about it. If it is intentional, I believe the error message could be improved.
MWE:
import jax
import jax.numpy as jnp
import equinox as eqx
@jax.jit
def partition(params, params_mask):
    return eqx.partition(params, filter_spec=params_mask)
params = ((1., 2.), 3.)
params_mask = ((True, False), False)
partition(params, params_mask)
Calling it crashes with ValueError: filter_spec must consist of booleans and callables only. ?!
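A minimal sketch of why this happens, using only JAX: under jax.jit every non-static argument is replaced by a tracer while the function is traced, so the Python bools in params_mask arrive as traced bool[] scalars rather than as bool instances. The list seen and the function inspect_mask are hypothetical probes for illustration.

```python
import jax

seen = []  # hypothetical probe: records what the function sees at trace time


@jax.jit
def inspect_mask(mask):
    # This body runs once during tracing; `mask` is a tracer here, not a Python bool.
    seen.append((isinstance(mask, bool), isinstance(mask, jax.Array), mask.dtype))
    return mask


inspect_mask(True)
# seen[0][0] is False: the mask is no longer a Python bool,
# which is exactly what eqx.partition's filter_spec check rejects.
```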
On the other hand, the two examples below work well.
@eqx.filter_jit
def partition(params, params_mask):
    return eqx.partition(params, filter_spec=params_mask)
from functools import partial

@partial(jax.jit, static_argnames=["params_mask"])
def partition(params, params_mask):
    return eqx.partition(params, filter_spec=params_mask)
Ah, interesting! I've not seen filter_spec passed across a JIT boundary like this before. Usually this is something created before a JIT boundary, so that params can be split into the dynamic and static pieces:
@partial(jax.jit, static_argnames=["static"])
def run(dynamic, static):
    ...
dynamic, static = eqx.partition(params, filter_spec=params_mask)
run(dynamic, static)
(This is pretty much what eqx.filter_jit does for you automatically under-the-hood.)
On improving the error message, I agree that sounds reasonable! Probably we can detect when it is a traced value, and issue a slightly different message then? If so then I'd be happy to take a PR on this.
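A standalone sketch of that idea, outside Equinox: check_mask_leaf is a hypothetical helper, not Equinox's actual _make_filter_tree. It tests isinstance(mask, jax.Array), which is True for tracers as well as concrete arrays, rather than jax.core.Tracer.

```python
import jax


def check_mask_leaf(mask):  # hypothetical helper, not part of Equinox
    """Validate one filter_spec leaf, with a dedicated message for JAX arrays."""
    if isinstance(mask, bool) or callable(mask):
        return mask
    if isinstance(mask, jax.Array):
        # Tracers are jax.Array instances, so this also catches masks traced by jax.jit.
        raise ValueError(
            "`filter_spec` leaf values cannot be JAX arrays or traced arrays. "
            "If filter_spec is passed into a jax.jit-compiled function, mark it "
            "as static (e.g. via static_argnames) or use eqx.filter_jit."
        )
    raise ValueError("filter_spec must consist of booleans and callables only.")


@jax.jit
def jitted_check(mask):
    check_mask_leaf(mask)  # mask is a tracer here, so this raises at trace time
    return mask


message = ""
try:
    jitted_check(True)
except ValueError as err:
    message = str(err)
```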
I know this is not the most common use of eqx.partition, but I want it inside a jitted loop: I need an alternate solver in a case where I cannot efficiently use the alternate options provided by optax (essentially, through their multi_transform). It looks like it can work as long as my parameter masks are static!
I added the corresponding check to _make_filter_tree so that it raises a more informative error:
if isinstance(mask, jax.core.Tracer):
    raise ValueError(
        "`filter_spec` leaf values cannot be traced arrays."
    )
But I cannot write a working unit test for it: the tests run with a runtime typechecker, and jaxtyping catches the tracer value before my check does 😄:
> filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
E jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of equinox._filters._make_filter_tree.<locals>._filter_tree.
E The problem arose whilst typechecking parameter 'mask'.
E Actual value: bool[]
E Expected type: bool | collections.abc.Callable[[Any], bool].
Closing as resolved in #1038!