What is the difference between jax.lax.stop_gradient and eqx.static_field=True?
Hello, I'm going to make an ODE simulator for the Bloch equation. I would like to define a module which has some attributes of array type:

```python
from jaxtyping import Array, Float
import equinox as eqx

class BlochModule(eqx.Module):
    array1: Float[Array, "x y z"]
    array2: Float[Array, "x y z"]
```
What I would like to do is: when I initialize with a np.array, it is static and does not need a gradient; when I initialize with a jnp.array, it is marked as dynamic in JAX (requires gradient). I have read through the whole documentation, and there seem to be two ways of doing this: eqx.static_field = True or jax.lax.stop_gradient. Which one should I choose, and how should I wrap this into the initialization?
I've also checked some issues, where it is said that the eqx.static_field = True approach is dangerous. What should I do to realize this?
I think what you want here is custom filtering, e.g. as done in this example (look for the filter_spec definition in the main function).
Would that work for your use case? It would allow you to control what ends up in the dynamic and static partition of your model, so you can compute gradients only with respect to the arrays that you want gradients for.
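For reference, here is a minimal sketch of that pattern, using hypothetical field names (`trainable`/`fixed`): build a boolean filter_spec pytree, switch on only the leaves you want gradients for, and partition/combine around the loss.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class Model(eqx.Module):
    trainable: jax.Array
    fixed: jax.Array


model = Model(trainable=jnp.ones(3), fixed=jnp.arange(3.0))

# Start with "no gradients anywhere", then switch on the leaves we do want.
filter_spec = jtu.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(lambda m: m.trainable, filter_spec, replace=True)


def loss(diff_model, static_model):
    m = eqx.combine(diff_model, static_model)
    return jnp.sum(m.trainable**2) + jnp.sum(m.fixed**2)


diff_model, static_model = eqx.partition(model, filter_spec)
grads = eqx.filter_grad(loss)(diff_model, static_model)
# grads.trainable is 2 * model.trainable; no gradient is produced for `fixed`.
```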
You want jax.lax.stop_gradient. This will block autodiff from operating through that variable.
eqx.field(static=True) does something totally different -- it marks a dataclass field as not being part of the pytree structure.
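For illustration, a minimal sketch (hypothetical field names) of what that means in practice: a static field disappears from the pytree leaves entirely, so it is never traced or differentiated, and it must be hashable.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class WithStatic(eqx.Module):
    weight: jax.Array
    name: str = eqx.field(static=True)  # stored in the treedef, not as a leaf


m = WithStatic(weight=jnp.ones(3), name="bloch")
print(jtu.tree_leaves(m))  # only the weight array; "name" is part of the structure
```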
As for how to use jax.lax.stop_gradient, usually something like this:
```python
class AddWithoutGradient(eqx.Module):
    x: jax.Array

    def __call__(self, y: jax.Array):
        x = jax.lax.stop_gradient(self.x)
        return x + y
```
jax.lax.stop_gradient "gets lost" when the input to a module is wrapped, and seems to require doing the wrapping directly inside the bound methods. This is why I didn't recommend it yesterday -- although we might be able to fix it through the module metaclass?
Take this example, which I would naively expect to work:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import wadler_lindig as wl


class ArrayContainer(eqx.Module):
    x: jax.Array


def parabolic_loss(tree):
    return jnp.sum(tree.x**2)


some_vector = 2.0 * jnp.ones(3)

a = ArrayContainer(some_vector)
wl.pprint(jax.grad(parabolic_loss)(a), short_arrays=False)  # Has gradient

b = ArrayContainer(jax.lax.stop_gradient(some_vector))
wl.pprint(jax.grad(parabolic_loss)(b), short_arrays=False)  # Also has gradient
```
@johannahaffner -- so this is a misunderstanding of how jax.lax.stop_gradient works. It has to be called from within the region that has jax.grad applied to it. It doesn't control a property of the array it is called on -- it's a function that is called within a computation graph (and this function is the identity function with zero gradient).
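To make that concrete, here is a minimal sketch (restating the ArrayContainer from above so the snippet is self-contained) where stop_gradient is applied inside the differentiated function, so the gradient does come out as zero:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class ArrayContainer(eqx.Module):
    x: jax.Array


def parabolic_loss_no_grad(tree):
    # stop_gradient is applied *inside* the function that jax.grad traces,
    # so autodiff sees an identity function with zero gradient here.
    x = jax.lax.stop_gradient(tree.x)
    return jnp.sum(x**2)


c = ArrayContainer(2.0 * jnp.ones(3))
print(jax.grad(parabolic_loss_no_grad)(c).x)  # zeros
```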
Thank you for clearing that up!
FWIW, here are some collected related issues; these four are directly relevant:
- https://github.com/patrick-kidger/equinox/issues/909
- https://github.com/patrick-kidger/equinox/issues/710
- https://github.com/patrick-kidger/equinox/issues/31, including this comment
- https://github.com/patrick-kidger/equinox/issues/214
Hello, thank you for all the feedback and comments, @johannahaffner @patrick-kidger! I've checked all the aforementioned issues and comments, and thought about how I should organize the software structure. What I'm currently building is a very complicated simulation describing complicated physics. To keep the software manageable, I plan for variables of jnp.array type to have gradients, while other types (numpy, Python built-in types) are non-trainable.
My naive ideas after reviewing all the posts:
**Option 1**: Use filter_spec. The idea is: when designing the software, make all of the fixed structure (str and similar datatypes) static fields (they will never need gradients!), so the static arguments slightly improve performance. Attributes that might need gradients I type as Union[Array[""], float] (for a scalar) or Array. Scalars are easy to handle, because Equinox treats Python built-in types as static. For Array-typed attributes, my filter_spec solution is to first mark all leaves as non-trainable, then mark all the jnp.array attributes as trainable.
**Option 2**: Make the attribute a @property, or use is_leaf, as mentioned in #31, like requires_grad=False in PyTorch.
However, I'm not sure whether the @property approach would cause some overhead in diffrax; as far as I know, a @property is not JIT-compiled.
Not sure whether my understanding is correct, but thanks for all the replies again!
I think option 1 sounds like a good choice! Although FWIW you probably won't need to use static fields at all -- using eqx.filter_{jit, grad, ...} is generally a better choice. (Static fields are really an advanced feature that I try to have people avoid using unless they're familiar with the details of e.g. how JIT caching interacts with pytree semantics.)
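As a minimal sketch of option 1 with those filter transforms (the field names here are hypothetical): build the filter_spec per leaf so that only jax.Array leaves count as trainable, while numpy arrays stay fixed.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np


class BlochParams(eqx.Module):
    t1: jax.Array      # a jnp array -> should get gradients
    grid: np.ndarray   # a plain numpy array -> fixed, no gradients


params = BlochParams(t1=jnp.array([1.0, 2.0]), grid=np.linspace(0.0, 1.0, 5))

# True only for leaves that are JAX arrays; numpy arrays stay False.
filter_spec = jtu.tree_map(lambda leaf: isinstance(leaf, jax.Array), params)


@eqx.filter_jit
def loss(diff_params, static_params):
    p = eqx.combine(diff_params, static_params)
    # grid participates in the computation but receives no gradient
    return jnp.sum(p.t1**2) + jnp.sum(jnp.asarray(p.grid))


diff_params, static_params = eqx.partition(params, filter_spec)
grads = eqx.filter_grad(loss)(diff_params, static_params)
# grads.t1 == 2 * params.t1; no gradient is produced for `grid`.
```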
Note that JAX totally will compile properties though. JAX uses a tracing compiler, see point 2 here.
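For instance, a property accessed inside a jitted function is traced like any other computation (hypothetical names):

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class Magnetization(eqx.Module):
    m: jax.Array

    @property
    def norm(self):
        # computed on access; under jit this is traced like any other op
        return jnp.linalg.norm(self.m)


@eqx.filter_jit
def f(mag: Magnetization):
    return mag.norm**2


print(f(Magnetization(jnp.array([3.0, 4.0]))))  # 25.0
```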