
Static Numpy arrays in Modules

Open · SNMS95 opened this issue 5 months ago · 1 comment

Hey,

I wanted to use equinox to hold two types of parameters: ones that require gradients (similar to PyTorch's requires_grad terminology) and others that do not. I tried to do this with the following code:

from jaxtyping import PyTree, Array
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
from typing import Union


class Parameter(eqx.Module):
    name: str
    current_val: Union[PyTree[Array], float]
    requires_grad: bool

    def __init__(self, name: str, current_val: Union[PyTree[Array], float],
                 requires_grad: bool):
        self.name = name
        self.requires_grad = requires_grad

        # Ensure current_val is a JAX array or PyTree of JAX arrays
        if requires_grad:
            self.current_val = jax.tree.map(jnp.asarray, current_val)
        else:
            self.current_val = jax.tree.map(np.asarray, current_val)

However, when taking gradients w.r.t. the module, I get gradients for all the parameters (due to JAX's promotion of NumPy arrays to JAX arrays). What I want is to prevent the non-trainable ones from participating in the AD graph entirely. I tried using eqx.field(static=True) (as TrainableParameter and StaticParameter classes), but it raises a warning that I am setting JAX arrays as static, even in the NumPy case, so I am wary of using it. The aim is to have a collection of these Parameters (again an eqx.Module) where I can selectively choose which Parameters participate in AD (similar to marking requires_grad in PyTorch).

  1. Should I be worried about the warning?
  2. Is there a better way (perhaps with stop_gradient) to achieve this effect, i.e. a collection of trainable and non-trainable parameters (still numerical arrays) in one module?

SNMS95 avatar Aug 13 '25 14:08 SNMS95

Yup, stop_gradient is the way to do this :)

patrick-kidger avatar Aug 13 '25 15:08 patrick-kidger