"A JAX array is being set as static" warning is raised unnecessarily
As of Equinox 0.11.6 and https://github.com/patrick-kidger/equinox/pull/800, the following minimal working example (MWE) raises `UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.`
```python
import numpy as np
import equinox as eqx


class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        return Foo(tuple(x_as_np + 1))


x = (3, 2)
foo = Foo(x)
foo.add_one()
# UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.
```
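A likely trigger, assuming the new check treats NumPy scalars the same way as arrays: although `tuple(x_as_np + 1)` looks like `(4, 3)`, iterating over a NumPy array yields NumPy scalars (e.g. `np.int64`), not Python `int`s. This NumPy-only sketch shows what actually ends up in the static field:

```python
import numpy as np

x = (3, 2)
x_as_np = np.asarray(x)

# Iterating over a NumPy array yields NumPy scalars, not Python ints,
# so this tuple holds NumPy objects even though it compares equal to (4, 3).
shifted = tuple(x_as_np + 1)
print(shifted == (4, 3))  # True
print(all(isinstance(v, np.generic) for v in shifted))  # True
```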
This means that one cannot use NumPy operations (which are often simpler than the equivalent plain Python) to compute a static attribute without triggering the warning. This is a use case we have in dynamiqs; see, for instance, the `__mul__` method of this class, which represents an array in diagonal (DIA) sparse format. Note that we intentionally use `numpy` instead of `jax.numpy` so that this logic stays "static".
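One possible workaround, sketched under the assumption that the warning only fires when NumPy objects end up in the static field: convert NumPy results back to plain, hashable Python values before constructing the module. `to_static` below is a hypothetical helper, not part of Equinox or dynamiqs:

```python
import numpy as np


def to_static(value):
    """Hypothetical helper (not part of Equinox or dynamiqs).

    Recursively converts NumPy arrays and NumPy scalars into plain Python
    scalars and (nested) tuples, which are hashable and contain no NumPy objects.
    """
    if isinstance(value, np.ndarray):
        if value.ndim == 0:
            # A 0-d array becomes a Python scalar.
            return value.item()
        # Iterate over the leading axis and convert each sub-array/scalar.
        return tuple(to_static(v) for v in value)
    if isinstance(value, np.generic):
        # NumPy scalar (e.g. np.int64) -> Python scalar (e.g. int).
        return value.item()
    if isinstance(value, (tuple, list)):
        # Lists become tuples so the result stays hashable.
        return tuple(to_static(v) for v in value)
    return value


x_as_np = np.asarray((3, 2))
new_x = to_static(x_as_np + 1)
print(new_x, [type(v).__name__ for v in new_x])  # (4, 3) ['int', 'int']
```

In the MWE, `add_one` would then return `Foo(to_static(x_as_np + 1))`. This keeps the NumPy logic intact, but it still requires an explicit conversion at every site that builds a static field, which is exactly the friction this issue is about.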