"JAX array is set as static" warning is raised unwantedly

Open gautierronan opened this issue 5 months ago • 5 comments

As of Equinox 0.11.6 and https://github.com/patrick-kidger/equinox/pull/800, the following MWE raises a UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.

import numpy as np
import equinox as eqx

class Foo(eqx.Module):
    x: tuple[int, int] = eqx.field(static=True)

    def add_one(self):
        x_as_np = np.asarray(self.x)
        return Foo(tuple(x_as_np+1))

x = (3, 2)
foo = Foo(x)
foo.add_one()
# UserWarning: A JAX array is being set as static! This can result in unexpected behavior and is usually a mistake to do.

This means that one cannot perform NumPy operations (which are often simpler than writing the same logic in plain Python) on a static attribute without triggering the warning. This is a use case we have in dynamiqs; see for instance the method __mul__ of this class, which represents an array in diagonal (DIA) sparse format. Note that we intentionally use numpy instead of jax.numpy so that this logic stays "static".
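For reference, here is a minimal sketch of what the MWE actually passes to Foo, using only NumPy. Iterating over a NumPy array yields NumPy scalar objects rather than built-in ints, whereas ndarray.tolist() converts the elements to plain Python ints. Whether passing plain ints avoids the warning in Equinox 0.11.6 is an assumption that has not been checked here, and the variable names are only illustrative.

```python
import numpy as np

x = (3, 2)
x_as_np = np.asarray(x)

# tuple() over a NumPy array yields NumPy scalars (e.g. np.int64), not Python ints.
as_np_scalars = tuple(x_as_np + 1)

# ndarray.tolist() converts each element to a built-in Python type first.
as_py_ints = tuple((x_as_np + 1).tolist())

print([type(v).__name__ for v in as_np_scalars])
print([type(v).__name__ for v in as_py_ints])
```

If the warning is triggered by the NumPy scalars, returning Foo(tuple((x_as_np + 1).tolist())) from add_one might be a stopgap, though it does not address the underlying check.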

gautierronan · Sep 24 '24 17:09