Dynamically created containers
I want to use equinox as a container class for some parameters.
For now, lets say that each parameter, I can enforce some order i.e. I know the field before-hand and can thus, handle then in the __init__(). For example:
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
class ScalarParameter(eqx.Module, strict=True):
"""Stores a scalar variable with optional bounding."""
val: jnp.ndarray = eqx.field(converter=jnp.array)
enforce_bounds: bool
bounding_fn: str
lb: float = eqx.field(static=True)
ub: float = eqx.field(static=True)
def __init__(self, scalar_data: dict):
self.val = scalar_data["init"]
self.enforce_bounds = scalar_data.get("enforce_bounds", False)
self.bounding_fn = scalar_data.get("bounding_fn", "clip")
self.lb = scalar_data.get("lb", -np.inf)
self.ub = scalar_data.get("ub", np.inf)
def get_bounded_value(self) -> float:
"""Apply bounding function to enforce constraints."""
if not self.enforce_bounds:
return self.val
elif self.bounding_fn == "sigmoid":
return self.lb + (self.ub - self.lb) * jax.nn.sigmoid(self.val)
elif self.bounding_fn == "clip":
return jnp.clip(self.val, self.lb, self.ub)
else:
raise ValueError(f"Unknown bounding function: {self.bounding_fn}")
Now, I want to make a collection of parameters and have some helper functions for easily accessing and changing these parameters.
(I know that equinox modules are immutable so I need to use eqx.tree_at).
class ScalarContainer(eqx.Module, strict=True):
"""Container for scalar parameters, allowing attribute access."""
def __init__(self, scalars: dict[str, dict]):
for name, data in scalars.items():
object.__setattr__(self, name, ScalarParameter(data))
def get_scalar(self, name: str) -> jnp.ndarray:
"""Get the bounded value of a scalar."""
if hasattr(self, name):
return getattr(self, name).get_bounded_value()
raise KeyError(f"Scalar '{name}' not found.")
def set_scalar(self, name: str, new_value: float) -> 'ScalarContainer':
"""Set the value of a scalar parameter."""
path = lambda tree: getattr(tree, name)
updated_container = eqx.tree_at(path, self, new_value)
return updated_container
def __getattr__(self, name):
"""Allow direct access to `.val` for scalars."""
if name in self.__dict__:
return getattr(self, name).val
raise AttributeError(f"Scalar '{name}' not found.")
But when I try setting a value, it gives me an error that AttributeError: Scalar 'learning_rate0' not found.. This attribute is however present in the __dict__ of the self object.
outer_params = ScalarContainer({
"learning_rate0": {"init": 0.01, "lb": 0.001, "ub": 0.1, "enforce_bounds": True, "bounding_fn": "clip"},
"penalty0": {"init": 1.0, "lb": 0.0, "ub": 10.0, "enforce_bounds": True, "bounding_fn": "sigmoid"},
}
)
print(outer_params)
outer_params.set_scalar("learning_rate0", 0.05)
Is there any way to achieve this result ?
This bit:
class ScalarContainer(eqx.Module, strict=True):
def __init__(self, scalars: dict[str, dict]):
for name, data in scalars.items():
object.__setattr__(self, name, ScalarParameter(data))
is definitely not okay! Equinox modules are frozen dataclasses, and these require you to declare your attributes.
It's totally true that Python allows you to sneak past the immutability of frozen dataclasses by using object.__setattr__, but as you've observed, this tends to result in issues downstream. :)
(If you really need some kind of anything-goes then I recommend assigning a dictionary as a single attribute, and putting everything in that.)
i've a question, altering a well defined attribute (the one that we must list after the 'class' keyword) on eq.module allowed right with object.setattr