eqx.filter breaks type annotations and post init checks
Since eqx.filter replaces non-matching leaves with None, all non-static module attributes should be annotated as T | None, and checks in __post_init__ should assume that attributes can be None. This results in a lot of boilerplate and makes jaxtyping annotations less convenient to use. Is there an idiomatic way around this issue? What is the recommended solution?
I'm afraid it gets worse than that! For example, jax.tree.map(lambda _: object(), some_module) will replace every leaf with a fresh object(). To take another common example, the in_axes and out_axes of jax.vmap(..., in_axes=..., out_axes=...) must be PyTrees whose structure matches the inputs and outputs but whose leaves are ints or None; so in this case (usually via jax.tree.map and eqx.tree_at) one must construct modules whose leaves are integers or Nones.
In general, any type of object may end up in a PyTree's leaves, and many kinds of manipulation rely on this ability; you've seen a few of them above.
So far none of this is specific to Equinox; this is all just true of PyTrees in general.
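As a sketch of that point, a plain `typing.NamedTuple`, which JAX treats as a pytree node with no Equinox involved, ends up with leaves that contradict its annotations in exactly the same way. `Params` is a hypothetical container:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    weight: jax.Array
    bias: jax.Array


params = Params(jnp.ones(2), jnp.zeros(2))

# The annotations say `jax.Array`, but tree manipulations ignore them.
axes = jax.tree.map(lambda _: None, params)
tags = jax.tree.map(lambda _: object(), params)

print(axes)  # Params(weight=None, bias=None)
```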
What this means is that when we manipulate a module as a PyTree, we stop expecting it to maintain any of its invariants (the types of its leaves, being callable, etc.). At that point it is simply a collection of nested containers.
Eventually it might return to behaving like a normal module (in particular, after performing equinox.partition followed by equinox.combine), but at other times we'll just continue to treat it as nothing more than a conveniently structured collection.
I hope that helps!
Thanks for clarifying! So, if I understand you correctly, tree_at/filter don't call __post_init__ after surgery?
Correct!