
Dynamically created containers

Open SNMS95 opened this issue 10 months ago • 2 comments

I want to use Equinox as a container class for some parameters. For now, let's say that I can impose some order on each parameter, i.e. I know its fields beforehand and can therefore handle them in __init__(). For example:

import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx

class ScalarParameter(eqx.Module, strict=True):
    """Stores a scalar variable with optional bounding."""
    val: jnp.ndarray = eqx.field(converter=jnp.array)
    enforce_bounds: bool
    bounding_fn: str
    lb: float = eqx.field(static=True)
    ub: float = eqx.field(static=True)

    def __init__(self, scalar_data: dict):
        self.val = scalar_data["init"]
        self.enforce_bounds = scalar_data.get("enforce_bounds", False)
        self.bounding_fn = scalar_data.get("bounding_fn", "clip")
        self.lb = scalar_data.get("lb", -np.inf)
        self.ub = scalar_data.get("ub", np.inf)

    def get_bounded_value(self) -> float:
        """Apply bounding function to enforce constraints."""
        if not self.enforce_bounds:
            return self.val
        elif self.bounding_fn == "sigmoid":
            return self.lb + (self.ub - self.lb) * jax.nn.sigmoid(self.val)
        elif self.bounding_fn == "clip":
            return jnp.clip(self.val, self.lb, self.ub)
        else:
            raise ValueError(f"Unknown bounding function: {self.bounding_fn}")

Now I want to make a collection of parameters, with some helper functions for easily accessing and changing them. (I know that Equinox modules are immutable, so I need to use eqx.tree_at.)

class ScalarContainer(eqx.Module, strict=True):
    """Container for scalar parameters, allowing attribute access."""

    def __init__(self, scalars: dict[str, dict]):
        for name, data in scalars.items():
            object.__setattr__(self, name, ScalarParameter(data))

    def get_scalar(self, name: str) -> jnp.ndarray:
        """Get the bounded value of a scalar."""
        if hasattr(self, name):
            return getattr(self, name).get_bounded_value()
        raise KeyError(f"Scalar '{name}' not found.")

    def set_scalar(self, name: str, new_value: float) -> 'ScalarContainer':
        """Set the value of a scalar parameter."""       
        path = lambda tree: getattr(tree, name)
        updated_container = eqx.tree_at(path, self, new_value)
        
        return updated_container

    def __getattr__(self, name):
        """Allow direct access to `.val` for scalars."""
        if name in self.__dict__:
            return getattr(self, name).val
        raise AttributeError(f"Scalar '{name}' not found.")

But when I try to set a value, I get AttributeError: Scalar 'learning_rate0' not found. However, this attribute is present in the __dict__ of the self object.

outer_params = ScalarContainer({
    "learning_rate0": {"init": 0.01, "lb": 0.001, "ub": 0.1, "enforce_bounds": True, "bounding_fn": "clip"},
    "penalty0": {"init": 1.0, "lb": 0.0, "ub": 10.0, "enforce_bounds": True, "bounding_fn": "sigmoid"},
})

print(outer_params)
outer_params.set_scalar("learning_rate0", 0.05)

Is there any way to achieve this result?

Posted by SNMS95 on Mar 13 '25 19:03

This bit:

class ScalarContainer(eqx.Module, strict=True):
    def __init__(self, scalars: dict[str, dict]):
        for name, data in scalars.items():
            object.__setattr__(self, name, ScalarParameter(data))

is definitely not okay! Equinox modules are frozen dataclasses, and these require you to declare your attributes.

It's totally true that Python allows you to sneak past the immutability of frozen dataclasses by using object.__setattr__, but as you've observed, this tends to result in issues downstream. :)

(If you really need some kind of anything-goes container, then I recommend assigning a dictionary as a single attribute and putting everything in that.)

Posted by patrick-kidger on Mar 13 '25 23:03

I have a question: altering a well-defined attribute (one of those we must declare in the class body) on an eqx.Module with object.__setattr__ is allowed, right?

Posted by carlesoctav on Aug 24 '25 07:08