eqx.filter breaks type annotations and post init checks

Open · norpadon opened this issue 6 months ago · 3 comments

Since eqx.filter replaces non-matching leaves with None, every non-static module attribute would have to be annotated as T | None. Likewise, checks in __post_init__ would have to allow for attributes being None. This results in tons of boilerplate and makes using jaxtyping less convenient. Is there an idiomatic way around this issue? What is the recommended solution?

norpadon · Jul 18 '25 00:07

I'm afraid it gets worse than that! For example, jax.tree.map(lambda _: object(), some_module) will replace every leaf with an object(). Or, to take another common example, jax.vmap(..., in_axes=..., out_axes=...) requires PyTree structures that match the inputs and outputs whilst having int or None leaves, so in this case one must construct modules (usually via jax.tree.map and eqx.tree_at) whose leaves are integers or Nones.

In general, any type may end up in a PyTree's leaves: many kinds of manipulation rely on this ability, a few of which you've seen above.

So far none of this is specific to Equinox, this is all just true of PyTrees in general.

What this means is that when we manipulate a module as a pytree, we stop expecting it to maintain any of its invariants (like the types of its leaves, being callable, etc.). At that point it is simply a collection of nested containers.

Eventually it might return to feeling like a normal module (in particular, when performing equinox.partition followed by equinox.combine), but at other times we'll just continue to treat it as nothing more than a conveniently structured collection.

I hope that helps!

patrick-kidger · Jul 18 '25 09:07

Thanks for clarifying! So, to check that I understand you correctly: tree_at and filter don't call __post_init__ after surgery?

norpadon · Jul 18 '25 10:07

Correct!

patrick-kidger · Jul 18 '25 10:07