Static Numpy arrays in Modules
Hey,
I wanted to use equinox to hold two types of parameters: ones that require gradients (`requires_grad`, to borrow PyTorch's terminology) and others that do not.
I tried to do this with the following code:
```python
from typing import Union

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, PyTree


class Parameter(eqx.Module):
    name: str
    current_val: Union[PyTree[Array], float]
    requires_grad: bool

    def __init__(self, name: str, current_val: Union[PyTree[Array], float],
                 requires_grad: bool):
        self.name = name
        self.requires_grad = requires_grad
        # Trainable values become JAX arrays; frozen values stay NumPy arrays
        if requires_grad:
            self.current_val = jax.tree.map(jnp.asarray, current_val)
        else:
            self.current_val = jax.tree.map(np.asarray, current_val)
```
However, when I take the gradient with respect to the module, I get gradients for all of the parameters, including the NumPy ones (I assume due to JAX's promotion). What I want is to keep the non-trainable ones out of the AD graph entirely. I tried using `eqx.field(static=True)` (splitting this into `TrainableParameter` and `StaticParameter` classes), but it warns that I am setting JAX arrays as static fields, even in the NumPy case, so I am wary of using it. The aim is to have a collection of these Parameters (itself an eqx module) in which I can selectively choose which Parameters participate in AD, similar to setting `requires_grad` in PyTorch.
- Should I be worried about the warning?
- Is there a better way (perhaps with `stop_gradient`) to achieve this, i.e., a collection of trainable and non-trainable parameters (still numerical arrays) in one module?
Yup, stop_gradient is the way to do this :)