
Understanding default field values when using a custom `__init__`

Status: Open · tilmantroester opened this issue 7 months ago · 2 comments

I have run into some behaviour that I find unintuitive when using custom __init__ methods with eqx.Module subclasses whose fields have default values. The documentation suggests that the default __init__ just fills in the fields, but it appears to do something more under the hood:

import equinox as eqx
import jax
import jax.numpy as jnp

class State(eqx.Module):
    x: jax.Array | None = None

s = State()
print(f"{s.x=}")   # s.x=None

s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
print(f"{s.x=}")   # s.x=Array([0, 1], dtype=int32)

So far, so good. But when I supply a custom __init__ method, this fails:

class StateWithInit(eqx.Module):
    x: jax.Array | None = None

    def __init__(self):
        pass

s = StateWithInit()
print(f"{s.x=}")  # s.x=None

s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
print(f"{s=}, {s.x=}")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[33], line 23
     20 s = StateWithInit()
     21 print(f"{s=}, {s.x=}")
---> 23 s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
     24 print(f"{s=}, {s.x=}")

File /scratch2/tilmant/Research/jax/jax_env/.pixi/envs/default/lib/python3.12/site-packages/equinox/_tree.py:237, in tree_at(where, pytree, replace, replace_fn, is_leaf)
    235 count = node_replace_fns.count(node)
    236 if count == 0:
--> 237     raise ValueError(
    238         "`where` does not specify an element or elements of `pytree`."
    239     )
    240 elif count == 1:
    241     pass

ValueError: `where` does not specify an element or elements of `pytree`.

This might be related to #301, since

s = State()
print(f"{s.__dict__=}")  # s.__dict__={'x': None}

s = StateWithInit()
print(f"{s.__dict__=}")  # s.__dict__={}

I originally ran into this when the default value for a field was turned from a function into a BoundMethod:

import equinox as eqx
from typing import Callable

class StateWithFunc(eqx.Module):
    f: Callable = lambda x: x

s = StateWithFunc()
print(f"{s.f=}")  # s.f=<function StateWithFunc.<lambda> at 0x7fcec04b4a40>


class StateWithFuncWithInit(eqx.Module):
    f: Callable = lambda x: x

    def __init__(self):
        pass

s = StateWithFuncWithInit()
print(f"{s.f=}")
# s.f=BoundMethod(
#   __func__=<function StateWithFuncWithInit.<lambda>>,
#   __self__=StateWithFuncWithInit(
#     f=BoundMethod(
#       __func__=<function StateWithFuncWithInit.<lambda>>, __self__=<recursive>
#     )
#   )
# )
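The BoundMethod comes from ordinary attribute lookup: when f is never stored on the instance, s.f falls through to the class, where a plain function behaves like a method and gets bound to s. A minimal stdlib-only sketch of that mechanism, assuming a plain `dataclasses.dataclass` as a stand-in for eqx.Module (Equinox shows the bound function as its BoundMethod wrapper, as in the output above, while plain Python shows a built-in `method`):

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass
class WithFunc:
    # The generated __init__ stores f on the instance.
    f: Callable = lambda v: v


@dataclasses.dataclass
class WithFuncWithInit:
    f: Callable = lambda v: v

    # dataclasses keeps a user-defined __init__ instead of generating one,
    # so f is never assigned on the instance and stays a class attribute.
    def __init__(self):
        pass


a = WithFunc()
b = WithFuncWithInit()

# Stored on the instance: lookup returns the plain function.
print(type(a.f).__name__)  # function
# Found only on the class: Python binds it to b like a method.
print(type(b.f).__name__)  # method
# b itself is passed in as the lambda's single argument.
print(b.f() is b)  # True
```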

What is the recommended pattern here? Do not use default field values with __init__? Do not override __init__ but use __post_init__ instead? Do it completely differently?

Edit: To clarify my use case, I am trying to do something like this:

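import dataclasses
from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp

class State(eqx.Module):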
    n_dim: int
    h: Callable
    f: Callable = lambda x: x

    x: jax.Array | None = None

    def __init__(self, n_dim, **kwargs):
        self.n_dim = n_dim
        if self.n_dim == 1:
            self.h = lambda x: 2*x
        else:
            self.h = lambda x: jnp.sin(x)

        # Initialise other fields with defaults or overwrite from kwargs
        for f in dataclasses.fields(self):
            if f.name in kwargs:
                setattr(self, f.name, kwargs[f.name])
            elif f.default is not dataclasses.MISSING:
                setattr(self, f.name, f.default)

s = State(n_dim=1)
s = State(n_dim=2, h=lambda x: jnp.cos(x))
s = State(n_dim=2, f=lambda x: x**2)
s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])

but the loop at the end of __init__ feels a bit hacky.

tilmantroester, Jun 13 '25 11:06

Right, so what you're seeing here is that in the first snippet, the autogenerated __init__ method will set self.x = None as an instance attribute.

Meanwhile, in your second snippet, it remains only a class attribute. This is what you're noticing when you check __dict__. Operations like eqx.tree_at operate specifically on instance attributes.
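The same split between instance and class attributes can be reproduced with the standard library alone. A minimal sketch, assuming a plain `dataclasses.dataclass` stands in for eqx.Module (which is dataclass-based): the user-defined __init__ is kept instead of the generated one, so the default value only lives on the class.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class State:
    # The generated __init__ assigns x on the instance.
    x: Optional[int] = None


@dataclasses.dataclass
class StateWithInit:
    x: Optional[int] = None

    def __init__(self):
        # Nothing is assigned, so x exists only as the class-level default.
        pass


s = State()
t = StateWithInit()
print(s.__dict__)  # {'x': None}
print(t.__dict__)  # {}
print(t.x)  # None, read from the class attribute StateWithInit.x
print("x" in vars(StateWithInit))  # True
```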

Ideally we'd actually detect if any fields were only filled in at the class level, not the instance level, and then error out. Unfortunately, we have to still allow this case for backward compatibility!
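A check of this kind could compare each dataclass field against the instance's own attributes. The sketch below is plain Python; `class_only_fields` is a hypothetical helper, not part of Equinox, and nothing here implies whether Equinox would raise, warn, or do something else with its result.

```python
import dataclasses
from typing import Any, List, Optional


def class_only_fields(obj: Any) -> List[str]:
    """Hypothetical helper: names of dataclass fields never assigned on the
    instance, so attribute lookup for them falls back to the class."""
    return [f.name for f in dataclasses.fields(obj) if f.name not in vars(obj)]


@dataclasses.dataclass
class StateWithInit:
    x: Optional[int] = None
    y: int = 0

    def __init__(self):
        # Only y is assigned on the instance; x keeps its class-level default.
        self.y = 1


print(class_only_fields(StateWithInit()))  # ['x']
```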

As for a solution, the usual approach is to treat class-level default values and custom __init__ methods as mutually exclusive: just set all the attributes during __init__! But you are also highlighting that this is a little non-ideal. If you think it would have helped you here, I'd be happy to take a PR adding a warning about any instance-level attributes not being filled in, alongside the existing error here:

https://github.com/patrick-kidger/equinox/blob/6ad63cbafc0796818b74b5e0b550d212903c905a/equinox/_module/_module.py#L421-L429
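For the use case in the issue, "set all the attributes during __init__" can mean moving the defaults out of the class body and into the __init__ signature. A minimal sketch under stated assumptions: it uses a plain dataclass instead of eqx.Module and `math.sin` instead of `jnp.sin` so that it runs without JAX, and `None` serves as the "not given" sentinel. With eqx.Module the same __init__ body is expected to carry over, but that has not been checked here.

```python
import dataclasses
import math
from typing import Callable, Optional, Sequence


@dataclasses.dataclass
class State:
    # No class-level defaults: every field is assigned in __init__.
    n_dim: int
    h: Callable
    f: Callable
    x: Optional[Sequence[int]]

    def __init__(self, n_dim, h=None, f=None, x=None):
        self.n_dim = n_dim
        if h is None:
            # Default depends on n_dim; math.sin stands in for jnp.sin.
            h = (lambda v: 2 * v) if n_dim == 1 else math.sin
        self.h = h
        self.f = f if f is not None else (lambda v: v)
        self.x = x


s = State(n_dim=1)
print(sorted(vars(s)))  # ['f', 'h', 'n_dim', 'x']
print(s.h(3))  # 6

s = State(n_dim=2, f=lambda v: v**2)
print(s.f(3))  # 9

# Because __init__ accepts every field by keyword, dataclasses.replace works.
s2 = dataclasses.replace(s, x=[0, 1])
print(s2.x)  # [0, 1]
```

The trade-off of the `None` sentinel is that `None` cannot be passed to mean a real value for h or f; a dedicated sentinel object avoids that if it matters.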

patrick-kidger, Jun 14 '25 22:06

You can use simple-pytree to create a pytree and then use it in an eqx.Module. It's very easy to create a custom pytree this way.

ak24watch, Aug 03 '25 01:08