Optax with multiple optimizers

Open ToddMorrill opened this issue 6 months ago • 3 comments

Optax has a function multi_transform, which is nice for using multiple optimizers. See here for a working example.

How should Equinox interact with multi_transform?

Suppose, for instance, that I want to freeze the last layer of my network using optax's set_to_zero function. NB: I'm aware of this example, but it is less general than my question. It only covers freezing a set of parameters (by excluding them from the differentiable set); it doesn't cover applying a different optax optimizer to each set of parameters.

My code might look something like:

import equinox as eqx
import jax
import optax

def create_param_labels(model):
    # initialize every leaf with the label "train"
    param_labels = jax.tree.map(lambda x: "train", model)
    # relabel the last layer's weight as "freeze"
    param_labels = eqx.tree_at(lambda tree: tree.layers[-1].weight, param_labels, replace="freeze")
    return param_labels
...
model_params, model_structure = eqx.partition(model, eqx.is_inexact_array)
param_labels = create_param_labels(model)
optimizer = optax.multi_transform({"train": optax.sgd(learning_rate=lr), "freeze": optax.set_to_zero()}, param_labels=param_labels)
optimizer_state = optimizer.init(model_params)

This results in an error because param_labels is a PyTree of the same type as the model, and the model implements a __call__ method, which makes param_labels callable. Under the hood, optax checks whether the mask (i.e., param_labels) is callable. Since it is, optax ends up calling param_labels(model_params), which isn't what we want.

We can't just delete the __call__ method from the param_labels object because it's a frozen dataclass. So it seems like the only way forward is to make param_labels a function that labels the PyTree nodes appropriately. I'm not enough of a PyTree pro to know whether there's an easy way to target the weight matrix of the last layer. Does anyone have a sense for how to do this?

ToddMorrill, Aug 04 '24 18:08