
Shouldn't the wrapped function in eqx.Lambda be treated as static?

Status: Open. adonath opened this issue 9 months ago • 5 comments

I have noticed the following behavior of the eqx.nn.Lambda module:

import jax
import equinox as eqx

layer = eqx.nn.Lambda(jax.nn.relu)
jax.tree.flatten(layer)

Which shows:

([<jax._src.custom_derivatives.custom_jvp at 0x117487250>],
 PyTreeDef(CustomNode(Lambda[('fn',), (), ()], [*])))

So the wrapped function is treated as a child by default and is not static / aux data. I found this a surprising default, because the wrapped function is typically not a PyTree node. I can see that it could be a callable node, but in that case it would more naturally be a module itself. Is there a specific reason why it shouldn't be treated as static by default?

adonath, Apr 19 '25 17:04

I just saw the corresponding FAQ entry: https://docs.kidger.site/equinox/faq/#typeerror-not-a-valid-jax-type

It explains how to adjust the filtering; however, I would see the function as part of the "metadata" of the module, just like other static attributes (such as https://github.com/patrick-kidger/equinox/blob/main/equinox/nn/_linear.py#L18).

adonath, Apr 19 '25 17:04

Indeed, callable nodes are possible, so for this reason it has to be a dynamic attribute. :)

patrick-kidger, Apr 21 '25 13:04

Thanks, @patrick-kidger, for the clarification! I would guess that the case where the callable is not a node is more common, but either behavior seems fine. Maybe it would be worth adding a sentence to the docstring clarifying that fn is treated as a node?

adonath, Apr 21 '25 16:04

Sounds reasonable to me! Happy to take a PR on this.

patrick-kidger, Apr 21 '25 18:04

See https://github.com/patrick-kidger/equinox/pull/1009

adonath, Apr 21 '25 18:04