How to mark parameters as trainable or not?
Greetings!
I have custom layers in Equinox that look approximately like this:
```python
class ProductLayer(InnerLayer):
    child_layers: List[Union[SumLayer, InputLayer]]
    edges: BCOO

class SumLayer(InnerLayer):
    log_weights: List[BCOO]
    child_layers: Union[List[ProductLayer], List[InputLayer]]

class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
    interval: jax.Array
```
I now want to exclude ProductLayer.edges from the parameters of a model, since they cannot be adjusted by gradient descent. Furthermore, SumLayer.log_weights.indices cannot be adjusted either, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?
See the FAQ: https://docs.kidger.site/equinox/faq/#how-to-mark-arrays-as-non-trainable-like-pytorchs-buffers
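For the specific question of excluding particular leaves from eqx.partition, you can also build a per-leaf filter spec yourself. A minimal sketch on a toy module (the field names here are illustrative, not the ones from the question):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Layer(eqx.Module):
    weights: jax.Array
    edges: jax.Array  # not trainable

layer = Layer(jnp.ones(3), jnp.zeros(3))

# Start from "everything trainable", then switch off individual leaves.
filter_spec = jax.tree_util.tree_map(lambda _: True, layer)
filter_spec = eqx.tree_at(lambda l: l.edges, filter_spec, replace=False)

params, static = eqx.partition(layer, filter_spec)
# `params` now contains only `weights`; `edges` lives in `static`.
```

The same pattern scales to more deeply nested modules by pointing the `where` function of eqx.tree_at at the relevant leaves.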
There is a risk to the approach suggested in the FAQ that should at least be highlighted in the docs: the non-trainable arrays may still be penalized by regularization.
```python
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from optax import adamw

class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)

@eqx.filter_value_and_grad
def loss(model, x):
    return model(x)

model = Model(jnp.ones(()), jnp.ones(()))
loss_value, grad = loss(model, 2)

optimizer = adamw(1e-1)  # Optimizer with regularization (weight decay)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(grad, opt_state, eqx.filter(model, eqx.is_inexact_array))
model = eqx.apply_updates(model, updates)

assert model.buffer == jnp.ones(())  # Fails: weight decay has modified the buffer!
```
Unless I am missing a downside, the approach I think should be recommended is to use a wrapper class (NonTrainable) to wrap non-trainable nodes, and to partition the parameters e.g. with:
```python
params, static = eqx.partition(
    model,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)
```
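For concreteness, a minimal sketch of what such a wrapper and the corresponding unwrapping step could look like (the names NonTrainable and unwrap are illustrative here, not an existing Equinox API):

```python
import equinox as eqx
import jax

class NonTrainable(eqx.Module):
    # Wraps a subtree so that filters can treat it as a single non-trainable node.
    value: jax.Array

def unwrap(tree):
    # Replace each NonTrainable wrapper with its (gradient-stopped) contents before use.
    return jax.tree_util.tree_map(
        lambda leaf: jax.lax.stop_gradient(leaf.value)
        if isinstance(leaf, NonTrainable)
        else leaf,
        tree,
        is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
    )
```

A model would then store e.g. `edges: NonTrainable` and call `unwrap(self)` at the start of `__call__`; since the wrapper is not an inexact array, the partition above keeps it out of the trainable parameters, and hence out of the optimizer state and any weight decay.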
Ah! That really isn't very good, you're right.
Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for eqx.apply_updates) with something that respects such a NonTrainable wrapper. This is just such an easy footgun!
FWIW I've landed on the Optax-wrapper approach. I have a trainable/non-trainable mask that I create early on and partition that way (a sketch follows below). I don't even bother with stop_gradient most of the time and pray that XLA does the dead-code elimination for me (it seems to).
For things that are really constants (e.g. rotary embeddings) I just materialize them in the kernel with ensure_compile_time_eval.
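A sketch of what that mask-and-partition approach can look like end to end (names here are illustrative; the point is that Optax only ever sees the trainable leaves, so weight decay never reaches the frozen array):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

class Model(eqx.Module):
    buffer: jax.Array  # conceptually frozen
    param: jax.Array

    def __call__(self, x):
        return self.param * x + self.buffer

model = Model(jnp.ones(()), jnp.ones(()))

# The trainable/non-trainable mask, created once up front.
trainable_mask = jax.tree_util.tree_map(eqx.is_inexact_array, model)
trainable_mask = eqx.tree_at(lambda m: m.buffer, trainable_mask, replace=False)

optimizer = optax.adamw(1e-1)
opt_state = optimizer.init(eqx.partition(model, trainable_mask)[0])

@eqx.filter_jit
def train_step(model, opt_state, x):
    params, static = eqx.partition(model, trainable_mask)

    def loss_fn(p):
        return jnp.sum(eqx.combine(p, static)(x) ** 2)  # illustrative loss

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Optax only receives the trainable leaves, so weight decay cannot reach `buffer`.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = eqx.apply_updates(params, updates)
    return eqx.combine(params, static), opt_state, loss

model, opt_state, loss = train_step(model, opt_state, jnp.array(2.0))
assert model.buffer == 1.0  # unchanged, even without stop_gradient
```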
Ah, nice! Okay, I think I'm convinced.
I'd be happy to take a PR implementing this, then.
Just for posterity, @danielward27 has written a small library that fixes this (along with enabling other parameterizations) in https://github.com/danielward27/paramax
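As I understand the paramax API (a rough sketch; see the paramax README for the authoritative usage), the wrapper idea above then looks something like:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import paramax

class Model(eqx.Module):
    buffer: paramax.NonTrainable
    param: jax.Array

    def __call__(self, x):
        m = paramax.unwrap(self)  # replaces wrappers with their (non-trainable) contents
        return m.param * x + m.buffer

model = Model(paramax.NonTrainable(jnp.ones(())), jnp.ones(()))
```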
Oh, this is excellent. Small and does exactly the right thing. @danielward27 would you be interested in having this advertised in the various Equinox-ecosystem READMEs, e.g. https://github.com/patrick-kidger/equinox/?tab=readme-ov-file#see-also-other-libraries-in-the-jax-ecosystem ?
Thanks! It would be great to be added to the list! I can open a pull request to add it if you would like - whatever is easiest for you.
Awesome :) Yup send a PR! The ecosystem lists appear in README.md and in docs/index.md.