Shouldn't the wrapped function in eqx.nn.Lambda be treated as static?
I have noticed the following behavior of the eqx.nn.Lambda module:
import jax
import equinox as eqx
layer = eqx.nn.Lambda(jax.nn.relu)
jax.tree.flatten(layer)
Which shows:
([<jax._src.custom_derivatives.custom_jvp at 0x117487250>],
PyTreeDef(CustomNode(Lambda[('fn',), (), ()], [*])))
So the wrapped function is treated as a child by default, not as static / aux data. I found this a surprising default because the function typically never acts as a PyTree node. I can see that it could be a callable node, but in that case it would rather be a module itself. Is there any specific reason why it shouldn't be treated as static by default?
I just saw the corresponding FAQ entry: https://docs.kidger.site/equinox/faq/#typeerror-not-a-valid-jax-type
It explains how to adjust the filtering; however, I would see the function as part of the "metadata" of the module, just like other static attributes (such as https://github.com/patrick-kidger/equinox/blob/main/equinox/nn/_linear.py#L18).
Indeed, callable nodes are possible, so for this reason it has to be a dynamic attribute. :)
Thanks, @patrick-kidger, for the clarification! I would guess the case of the callable not being a node is more common, but either behavior seems fine. Maybe adding one sentence to the docstring clarifying that fn is treated as a node would be a good idea?
Sounds reasonable to me! Happy to take a PR on this.
See https://github.com/patrick-kidger/equinox/pull/1009