
How to mark parameters as trainable or not?

tomsch420 opened this issue 1 year ago • 9 comments

Greetings!

I have custom layers in Equinox that look approximately like this:

from abc import ABC
from typing import List, Union

import jax
from jax.experimental.sparse import BCOO

# InnerLayer, InputLayer and ContinuousLayer are base classes defined
# elsewhere in my codebase.

class ProductLayer(InnerLayer):

    child_layers: List[Union[SumLayer, InputLayer]]
    edges: BCOO

class SumLayer(InnerLayer):

    log_weights: List[BCOO]
    child_layers: Union[List[ProductLayer], List[InputLayer]]

class ContinuousLayerWithFiniteSupport(ContinuousLayer, ABC):
    interval: jax.Array

I now want to exclude ProductLayer.edges from the trainable parameters of a model, since they cannot be adjusted by gradient descent. Furthermore, the indices of each BCOO in SumLayer.log_weights cannot be adjusted, and neither can ContinuousLayerWithFiniteSupport.interval. How can I best filter these out for the eqx.partition method?

tomsch420 avatar Sep 28 '24 11:09 tomsch420

See the FAQ: https://docs.kidger.site/equinox/faq/#how-to-mark-arrays-as-non-trainable-like-pytorchs-buffers

patrick-kidger avatar Sep 29 '24 14:09 patrick-kidger

There is a risk to the suggested approach that should at least be highlighted in the docs: the non-trainable arrays may still be penalized by regularization.

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from optax import adamw


class Model(eqx.Module):
    buffer: Array
    param: Array

    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)

@eqx.filter_value_and_grad
def loss_fn(model, x):
    return model(x)

model = Model(jnp.ones(()), jnp.ones(()))
loss, grad = loss_fn(model, 2.0)
optimizer = adamw(1e-1)  # Optimizer with weight-decay regularization
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
updates, opt_state = optimizer.update(
    grad, opt_state, eqx.filter(model, eqx.is_inexact_array)
)
model = eqx.apply_updates(model, updates)
# stop_gradient zeroes the gradient, but adamw's weight decay still
# shrinks the buffer, so this assertion fails:
assert model.buffer == jnp.ones(())  # Fails!

Unless I am missing a downside, the approach I think should be recommended is to use a wrapper class (NonTrainable) around non-trainable nodes, and to partition parameters e.g. with:

params, static = eqx.partition(
    model,
    eqx.is_inexact_array,
    is_leaf=lambda leaf: isinstance(leaf, NonTrainable),
)
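
For concreteness, a minimal sketch of such a wrapper, with an unwrap helper the model would call before using the wrapped values (both names are illustrative, not an existing Equinox API):

import equinox as eqx
import jax
from typing import Any


class NonTrainable(eqx.Module):
    # Wraps an arbitrary pytree to mark it as non-trainable.
    value: Any


def unwrap(model):
    # Replace each NonTrainable node with its gradient-stopped contents.
    is_nt = lambda x: isinstance(x, NonTrainable)
    return jax.tree_util.tree_map(
        lambda x: jax.lax.stop_gradient(x.value) if is_nt(x) else x,
        model,
        is_leaf=is_nt,
    )

With the partition above, each NonTrainable node is treated as a single leaf; it is not itself an inexact array, so the whole wrapper lands on the static side and the optimizer (weight decay included) never sees it.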

danielward27 avatar Oct 21 '24 16:10 danielward27

Ah! That really isn't very good, you're right.

Hmm, I'm trying to figure out if there's a way to handle this ergonomically. The best I can come up with is to wrap the Optax calls (like we already do for eqx.apply_updates) with something that respects such a NonTrainable wrapper. This is just such an easy footgun!

patrick-kidger avatar Oct 21 '24 18:10 patrick-kidger

FWIW I've landed on the optax wrapper approach. I have a trainable/non_trainable mask that I create early on and partition that way. I don't even bother with stop_grad most of the time and pray that XLA does the DCE for me (it seems to).
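
As a hedged sketch of that mask idea (the names and toy model here are illustrative, not the actual setup described above), one way to realize it is optax.masked, which hides the frozen leaves from the optimizer entirely, weight decay included:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax


class Model(eqx.Module):
    buffer: jax.Array
    param: jax.Array


model = Model(jnp.ones(()), jnp.ones(()))
params = eqx.filter(model, eqx.is_inexact_array)

# Boolean mask with the same structure as params: True means trainable.
# Here the buffer is frozen by hand; in practice the mask would be built
# once from whatever non-trainable annotation the model carries.
trainable = jax.tree_util.tree_map(lambda _: True, params)
trainable = eqx.tree_at(lambda m: m.buffer, trainable, replace=False)

optimizer = optax.masked(optax.adamw(1e-1), trainable)
opt_state = optimizer.init(params)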

For things that are really constants (e.g. rotary embeddings), I just materialize those in the kernel with jax.ensure_compile_time_eval.
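
A minimal sketch of that trick (the table below is a stand-in for a real rotary-frequency computation): inside jax.ensure_compile_time_eval, the computation runs eagerly at trace time, so its result is baked into the compiled kernel as a constant rather than carried as a model parameter.

import jax
import jax.numpy as jnp


@jax.jit
def attend(x):
    with jax.ensure_compile_time_eval():
        # Evaluated once at trace time; inv_freq becomes a compile-time
        # constant of the kernel instead of a traced input.
        inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, 8, 2) / 8.0))
    return x * inv_freq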

dlwh avatar Oct 22 '24 20:10 dlwh

Ah, nice! Okay, I think I'm convinced.

I'd be happy to take a PR implementing this, then.

patrick-kidger avatar Oct 22 '24 20:10 patrick-kidger

Just for posterity, @danielward27 has written a small library that fixes this (along with enabling other parameterizations) in https://github.com/danielward27/paramax
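
A hedged sketch of what usage looks like there, going by the paramax README at the time of writing (the NonTrainable and unwrap names are paramax's; double-check the repo for the current API):

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array
from paramax import NonTrainable, unwrap


class Model(eqx.Module):
    buffer: NonTrainable
    param: Array

    def __call__(self, x):
        # unwrap replaces NonTrainable nodes with their gradient-stopped
        # contents, so neither gradients nor weight decay reach them.
        return self.param * x + unwrap(self.buffer)


model = Model(NonTrainable(jnp.ones(())), jnp.ones(()))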

smorad avatar Jan 21 '25 03:01 smorad

Oh, this is excellent. Small and does exactly the right thing. @danielward27 would you be interested in having this advertised in the various Equinox-ecosystem READMEs, e.g. https://github.com/patrick-kidger/equinox/?tab=readme-ov-file#see-also-other-libraries-in-the-jax-ecosystem ?

patrick-kidger avatar Jan 21 '25 11:01 patrick-kidger

Thanks! It would be great to have it added to the list! I can do a pull request to add it if you would like; whatever is easiest for you.

danielward27 avatar Jan 22 '25 14:01 danielward27

Awesome :) Yup send a PR! The ecosystem lists appear in README.md and in docs/index.md.

patrick-kidger avatar Jan 22 '25 18:01 patrick-kidger