Understanding default field values when using a custom `__init__`
I have run into some (for me) unintuitive behaviour when using custom `__init__` methods with `equinox.Module`s that have default values for their fields. The documentation suggests that the default `__init__` just fills in the fields, but it appears to do some other magic under the hood as well:
```python
import jax
import jax.numpy as jnp
import equinox as eqx

class State(eqx.Module):
    x: jax.Array | None = None

s = State()
print(f"{s.x=}")  # s.x=None
s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
print(f"{s.x=}")  # s.x=Array([0, 1], dtype=int32)
```
So far, so good. But when I supply a custom __init__ method, this fails:
```python
class StateWithInit(eqx.Module):
    x: jax.Array | None = None

    def __init__(self):
        pass

s = StateWithInit()
print(f"{s.x=}")  # s.x=None
s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
print(f"{s=}, {s.x=}")
```
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[33], line 23
20 s = StateWithInit()
21 print(f"{s=}, {s.x=}")
---> 23 s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
24 print(f"{s=}, {s.x=}")
File /scratch2/tilmant/Research/jax/jax_env/.pixi/envs/default/lib/python3.12/site-packages/equinox/_tree.py:237, in tree_at(where, pytree, replace, replace_fn, is_leaf)
235 count = node_replace_fns.count(node)
236 if count == 0:
--> 237 raise ValueError(
238 "`where` does not specify an element or elements of `pytree`."
239 )
240 elif count == 1:
241 pass
ValueError: `where` does not specify an element or elements of `pytree`.
```
This might be related to #301, since
```python
s = State()
print(f"{s.__dict__=}")  # s.__dict__={'x': None}
s = StateWithInit()
print(f"{s.__dict__=}")  # s.__dict__={}
```
I originally ran into this when the default value for a field was turned from a function into a BoundMethod:
```python
import equinox as eqx
from typing import Callable

class StateWithFunc(eqx.Module):
    f: Callable = lambda x: x

s = StateWithFunc()
print(f"{s.f=}")  # s.f=<function StateWithFunc.<lambda> at 0x7fcec04b4a40>

class StateWithFuncWithInit(eqx.Module):
    f: Callable = lambda x: x

    def __init__(self):
        pass

s = StateWithFuncWithInit()
print(f"{s.f=}")
# s.f=BoundMethod(
#     __func__=<function StateWithFuncWithInit.<lambda>>,
#     __self__=StateWithFuncWithInit(
#         f=BoundMethod(
#             __func__=<function StateWithFuncWithInit.<lambda>>, __self__=<recursive>
#         )
#     )
# )
```
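For comparison, ordinary Python shows the same split. The sketch below uses plain `dataclasses` rather than `eqx.Module`, assuming that equinox's `BoundMethod` plays the role of Python's built-in bound method here: a function stored on the instance is returned unchanged, while a function found only on the class gets bound to the instance on lookup. The class names are hypothetical.

```python
import dataclasses
import types
from typing import Callable


@dataclasses.dataclass
class WithFunc:
    # The generated __init__ copies the default onto the instance.
    f: Callable = lambda x: x


@dataclasses.dataclass
class WithFuncCustomInit:
    f: Callable = lambda x: x

    # dataclasses keeps a user-defined __init__, so `f` is never set on the instance.
    def __init__(self):
        pass


a = WithFunc()
b = WithFuncCustomInit()

print(isinstance(a.f, types.FunctionType))  # True: instance attribute, returned as-is
print(isinstance(b.f, types.MethodType))    # True: class attribute, bound to `b` on lookup
print(b.f() is b)                           # True: the lambda receives `b` as its `x`
```

The last line is why the recursive `__self__` appears in the equinox output: the lambda has been bound to the module instance itself.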
What is the recommended pattern here? Avoid default field values when defining a custom `__init__`? Not override `__init__` at all and use `__post_init__` instead? Or do something completely different?
Edit: To clarify my use case, I am trying to do something like this:
```python
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp
import equinox as eqx

class State(eqx.Module):
    n_dim: int
    h: Callable
    f: Callable = lambda x: x
    x: jax.Array | None = None

    def __init__(self, n_dim, **kwargs):
        self.n_dim = n_dim
        if self.n_dim == 1:
            self.h = lambda x: 2 * x
        else:
            self.h = lambda x: jnp.sin(x)
        # Initialise the other fields from kwargs, falling back to the class-level defaults
        for f in dataclasses.fields(self):
            if f.name in kwargs:
                setattr(self, f.name, kwargs[f.name])
            elif f.default is not dataclasses.MISSING:
                setattr(self, f.name, f.default)

s = State(n_dim=1)
s = State(n_dim=2, h=lambda x: jnp.cos(x))
s = State(n_dim=2, f=lambda x: x**2)
s = eqx.tree_at(lambda x: [x.x], s, [jnp.arange(2)])
```
But the loop at the end of `__init__` feels a bit hacky.
Right, so what you're seeing here is that in the first snippet, the autogenerated `__init__` method sets `self.x = None` as an instance attribute.
Meanwhile, in your second snippet, `x` remains only a class attribute. This is what you're noticing when you check `__dict__`. Operations like `eqx.tree_at` operate specifically on instance attributes.
Ideally we'd actually detect if any fields were only filled in at the class level, not the instance level, and then error out. Unfortunately, we have to still allow this case for backward compatibility!
As for a solution, the usual approach is to treat class-level default values and custom `__init__` methods as mutually exclusive: just set all the attributes during `__init__`! But you are also highlighting that this is a little non-ideal. If you think it would have helped you here, then I'd be happy to take a PR adding a warning about any instance-level attributes not being filled in, alongside the existing error here:
https://github.com/patrick-kidger/equinox/blob/6ad63cbafc0796818b74b5e0b550d212903c905a/equinox/_module/_module.py#L421-L429
You can use simple-pytree to create the pytree and then use it in an `eqx.nn.Module`. It's very easy to create a custom pytree this way.