equinox
Optax with multiple optimizers
Optax has a function `multi_transform`, which is handy for using multiple optimizers. See here for a working example.
How should Equinox interact with `multi_transform`?
Suppose, for instance, I want to freeze the last layer of my network using optax's `set_to_zero` function. NB: I'm aware of this example, but it's less general than my question: it only covers freezing a set of parameters (by excluding them from the differentiable set), not applying different optax optimizers to different sets of parameters.
My code might look something like:
```python
import equinox as eqx
import jax
import optax


def create_param_labels(model):
    # initialize every leaf with the label "train"
    param_labels = jax.tree.map(lambda x: "train", model)
    # set the label for the last layer's weight to "freeze"
    param_labels = eqx.tree_at(lambda tree: tree.layers[-1].weight, param_labels, replace="freeze")
    return param_labels

...
model_params, model_structure = eqx.partition(model, eqx.is_inexact_array)
param_labels = create_param_labels(model)
optimizer = optax.multi_transform(
    {"train": optax.sgd(learning_rate=lr), "freeze": optax.set_to_zero()},
    param_labels=param_labels,
)
optimizer_state = optimizer.init(model_params)
```
This raises an error because `param_labels` is a PyTree with the same structure as the model; in fact it is an instance of the model's own class, and that class implements `__call__`, which makes `param_labels` callable. Under the hood, optax checks whether the labels (i.e., `param_labels`) are callable. Since they are, optax calls `param_labels(model_params)` instead of using the tree as labels, which isn't what we want.
We can't simply remove the `__call__` method from the `param_labels` object because it's a frozen dataclass. So it seems the only way forward is to make `param_labels` a callable function that labels the PyTree nodes appropriately. I'm not enough of a PyTree pro to know whether there's a simple way to manipulate PyTrees so that such a function targets the weight matrix of the last layer. Does anyone have a sense for how to do this?