Using Frozen Parameters to Freeze Differing Parameters Based Upon Epoch
I am trying to figure out whether it is possible to efficiently use something like the frozen-parameters approach described here https://docs.kidger.site/equinox/examples/frozen_layer/ to set up a training loop in which a different subset of parameters is frozen depending on the epoch.
e.g. For simplicity's sake, say you have an MLP and you want the weights of the odd and even layers to alternate between being updated each epoch (this isn't my real use case). As far as I can tell, if you use the filter_spec approach, the opt_state will have None values for the frozen parameters, so you then run into issues when you apply the opposite filter. Am I missing something here?
Hi,
I work on something similar. For your example, what I can think of is to create two optimizers, each transforming and updating its parameters differently via a different learning-rate schedule. This can be done by customizing the learning-rate schedule. Then we need to separate the even-layer and odd-layer parameters. Finally, we can use optax.multi_transform (see docs) to select which optimizer is applied to which parameters.
Here is a small example:
class Model(eqx.Module):
    even_layer: eqx.Module
    odd_layer: eqx.Module

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # return either 0 or the given learning rate based on the epoch
        return lr  # a dummy return
    return lr_scheduler

lr = 1e-3
even_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=True, num_steps_per_epoch=1000))
odd_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=False, num_steps_per_epoch=1000))

# assume that your model is initialized
model = Model()

# first give all parameters the `even` label
even_labels = jax.tree_util.tree_map(lambda _: "even", model)

def get_odd_layer(tree: Model):
    # return all the odd layers; might be more complicated in your model
    return tree.odd_layer

# replace the odd params with a new label
final_labels = eqx.tree_at(
    where=get_odd_layer,  # find where the odd layer is
    pytree=even_labels,
    replace_fn=lambda _: "odd",  # and replace with the new label
)

# the final optimizer will select a specific optimizer based on the label
optim = optax.multi_transform(
    transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,
    },
    param_labels=final_labels,
)
Another solution is to use optax.chain with optax.masked (see this), where we can mask some of the parameters on or off. But I find optax.multi_transform more convenient, since we can define many labels for the transformations.
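For reference, here is a minimal sketch of that optax.chain + optax.masked pattern (my own illustration, using a plain dict of parameters rather than the Equinox model above): Adam produces updates for everything, and optax.set_to_zero is masked onto the parameters we want frozen.

import jax.numpy as jnp
import jax.tree_util as jtu
import optax

# hypothetical parameters, grouped so that the mask is easy to write
params = {"even": jnp.ones((3,)), "odd": jnp.ones((3,))}

# True marks the leaves whose updates should be zeroed out (i.e. frozen)
freeze_mask = {"even": False, "odd": True}

optim = optax.chain(
    optax.adam(1e-3),
    optax.masked(optax.set_to_zero(), freeze_mask),
)

opt_state = optim.init(params)
grads = jtu.tree_map(jnp.ones_like, params)          # dummy gradients
updates, opt_state = optim.update(grads, opt_state)  # the "odd" update is now zero
new_params = optax.apply_updates(params, updates)

Note that Adam's moment statistics are still updated for the frozen parameters in this arrangement; their updates are simply zeroed before being applied.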
Thank you for the advice. I will give them a try.
Edit: How is "is_even" tested? Obviously, the labels are set to odd or even, but how does the optimiser know whether "is_even" is true?
Edit: Silly question, but with this multi_transform approach, how do you set up the training loop?
e.g. Normally with an Equinox "model" you would use a setup something like this:
opt_init, opt_update = optax.adam(config.learning_rate)
opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))
...
loss, grads = loss_func(model, x, y)
updates, update_opt_state = opt_update(grads, opt_state)
update_model = eqx.apply_updates(model, updates)
with the following, I get a "TypeError: Shapes must be 1D sequences of concrete values of integer type, got (None, 64)." upon the call to opt_init(...)
opt_init, opt_update = optax.multi_transform(
    transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,
    },
    param_labels=final_labels,
)
opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))
Uhm, I don't know why it throws a TypeError here. Would you mind providing some more details (e.g., code, error trace)?
As far as I know, there is no further setup, and the training loop is the same as with normal Equinox. I elaborate more in the example code here:
import equinox as eqx
import optax
import jax
import jax.tree_util

class Model(eqx.Module):
    even_layer: eqx.Module
    odd_layer: eqx.Module

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # return either 0 or the given learning rate based on the epoch
        return lr  # a dummy return
    return lr_scheduler

lr = 1e-3
even_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=True, num_steps_per_epoch=1000))
odd_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=False, num_steps_per_epoch=1000))

# assume that your model is initialized
model = Model(odd_layer=eqx.nn.Linear(3, 3, key=jax.random.PRNGKey(0)),
              even_layer=eqx.nn.Linear(3, 1, key=jax.random.PRNGKey(1)))

# first give all parameters the `even` label
even_labels = jax.tree_util.tree_map(lambda _: "even", model)

def get_odd_layer(tree: Model):
    # return all the odd layers; might be more complicated in your model
    return tree.odd_layer

# replace the odd params with a new label
final_labels = eqx.tree_at(
    where=get_odd_layer,  # find where the odd layer is
    pytree=even_labels,
    replace_fn=lambda _: "odd",  # and replace with the new label
)

# the final optimizer will select a specific optimizer based on the label
optim_init, optim_update = optax.multi_transform(
    transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,
    },
    param_labels=final_labels,
)

optim_init(eqx.filter(model, eqx.is_inexact_array))
Here, optim_init works fine.
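To make the "no further setup" point concrete, here is a minimal sketch of a training step around the optim_init/optim_update pair above (my own sketch; the forward pass composing odd_layer and even_layer is just an assumption for illustration, since this toy Model defines no __call__):

import jax.numpy as jnp

opt_state = optim_init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit
def make_step(model, opt_state, x, y):
    def loss_fn(m):
        # hypothetical forward pass: odd_layer followed by even_layer
        pred_y = jax.vmap(m.even_layer)(jax.vmap(m.odd_layer)(x))
        return jnp.mean((pred_y - y) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optim_update(grads, opt_state, eqx.filter(model, eqx.is_inexact_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss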
Edit: How is "is_even" tested? Obviously, the labels are set to odd or even, but how does the optimiser know whether "is_even" is true?
You can print out final_labels to check what the PyTree looks like. Unlike freezing a layer, where the leaves of the frozen nodes in the PyTree are None, the leaves of final_labels are assigned either the label "even" or the label "odd". I believe multi_transform uses the keys of the dictionary to select the optimizers,
transforms={
    "even": even_layer_optim,
    "odd": odd_layer_optim,
}
and final_labels provides the information that lets each optimizer find its corresponding parameters.
OK, so I believe I have found the issue, but I have no idea why it should be a problem. Any attempt to override the __call__ method results in various errors, depending upon exactly what is being called from it. Define exactly the same code in a named method and there are no errors.
Below is a (silly) example. I wonder if this is an Equinox problem or an optax problem?
import equinox as eqx
import jax
import jax.tree_util
import optax
from jax import vmap

class Model(eqx.Module):
    odd_mlp: eqx.nn.MLP
    even_mlp: eqx.nn.MLP

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.odd_mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=32, depth=3, key=jax.random.PRNGKey(0))
        self.even_mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=32, depth=3, key=jax.random.PRNGKey(1))

    def calc_mlp(self, x1):
        return vmap(self.odd_mlp)(x1)

    # INCLUDE THE CALL OVERRIDE, GET ERROR
    # def __call__(self, x1):
    #     return vmap(self.odd_mlp)(x1)

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # return either 0 or the given learning rate based on the epoch
        return lr  # a dummy return
    return lr_scheduler

lr = 1e-3
even_mlp_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=True, num_steps_per_epoch=1000))
odd_mlp_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=False, num_steps_per_epoch=1000))

# assume that your model is initialized
model = Model()

# first give all parameters the `even` label
even_labels = jax.tree_util.tree_map(lambda _: "even", model)

def get_odd_mlp(tree: Model):
    # return all the odd layers; might be more complicated in your model
    return tree.odd_mlp

# replace the odd params with a new label
final_labels = eqx.tree_at(
    where=get_odd_mlp,  # find where the odd layer is
    pytree=even_labels,
    replace_fn=lambda _: "odd",  # and replace with the new label
)

# the final optimizer will select a specific optimizer based on the label
optim_init, optim_update = optax.multi_transform(
    transforms={
        "even": even_mlp_optim,
        "odd": odd_mlp_optim,
    },
    param_labels=final_labels,
)

optim_init(eqx.filter(model, eqx.is_inexact_array))
Oh, I'm sorry. I encountered this problem before but didn't recall it.
A solution here is to use a different method name instead of __call__. I don't know of any other workaround (maybe optax.chain together with optax.masked can help, or passing param_labels as a function that decides the labels of the leaves).
optax.multi_transform takes param_labels with type Union[Any, Callable[[Any], Any]]. During validation, it checks whether the value we pass as final_labels is callable. If the Equinox module implements __call__, it recognizes final_labels as callable, whereas what we intend is to pass a PyTree of labels.
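A quick way to confirm this (my own check, assuming final_labels was built from the reproduction above with the __call__ override enabled):

# final_labels is a Model instance, and Model defines __call__, so optax's
# validation sees a callable and treats it as a label *function* rather than
# a label PyTree.
print(callable(final_labels))  # True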
Thanks for all your help, I now appear to have things running.
A couple of final questions: I am not sure I follow this part of the code. Why would you ever be returning an LR of 0?
def lr_scheduler(num_steps):
    current_epoch = num_steps // num_steps_per_epoch
    # return either 0 or the given learning rate based on the epoch
    return lr  # a dummy return
And is it possible to attach a different loss function for the odd / even optimiser? (Or, if the loss function contains numerous losses, different weightings of those, which are normally passed as arguments to the loss function?)
Overriding __call__: this is a known problem with Optax; the reason is as @anh-tong describes.
If you want to keep the __call__ method then a workaround is simply to wrap your model into a length-1 list; see #193 for an example of this issue in isolation.
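For illustration, here is one way the length-1-list workaround could look (a sketch of my own, reusing the names from the reproduction above); wrapping the labels and the parameters in a list means optax no longer sees a callable module at the top level:

optim_init, optim_update = optax.multi_transform(
    transforms={"even": even_mlp_optim, "odd": odd_mlp_optim},
    param_labels=[final_labels],  # wrapped in a length-1 list
)
opt_state = optim_init([eqx.filter(model, eqx.is_inexact_array)])

# Gradients then need to be wrapped and unwrapped in the same way, e.g.
#     updates, opt_state = optim_update([grads], opt_state)
#     model = eqx.apply_updates(model, updates[0])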
As for the original problem of using different frozen parameters: probably the easiest approach is to pass in a mask indicating which parameters you want to update at each step, e.g.:
import jax.numpy as jnp
import jax.tree_util as jtu

def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

@eqx.filter_jit
def make_step(model, opt_state, gradient_mask, x, y):
    grads = eqx.filter_grad(loss)(model, x, y)
    assert jtu.tree_structure(grads) == jtu.tree_structure(gradient_mask)
    grads = jtu.tree_map(lambda g, b: jnp.where(b, g, jnp.zeros_like(g)), grads, gradient_mask)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state
Here gradient_mask should be a PyTree with the same structure as model, with a True/False leaf for every floating-point array indicating whether that array should be updated with a gradient, and with None for all leaves that aren't floating-point arrays.
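To make that concrete, here is one way such masks could be built for the odd/even MLP model above (my own sketch; the mask_like helper is hypothetical):

import jax.tree_util as jtu

def mask_like(tree, value):
    # `value` at every floating-point array leaf, None everywhere else
    return jtu.tree_map(lambda x: value if eqx.is_inexact_array(x) else None, tree)

# update only the even MLP on one epoch, only the odd MLP on the next
update_even_mask = eqx.tree_at(lambda m: m.odd_mlp, mask_like(model, True),
                               replace=mask_like(model.odd_mlp, False))
update_odd_mask = eqx.tree_at(lambda m: m.even_mlp, mask_like(model, True),
                              replace=mask_like(model.even_mlp, False))

# Pick the mask per epoch with ordinary Python control flow, outside of jit:
#     gradient_mask = update_even_mask if epoch % 2 == 0 else update_odd_mask
#     model, opt_state = make_step(model, opt_state, gradient_mask, x, y)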
Why would you ever be returning an LR of 0?
A quick implementation could be:
def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        if is_even:
            return lr if current_epoch % 2 == 0 else 0.
        else:
            return lr if current_epoch % 2 == 1 else 0.
    return lr_scheduler
Or, if the loss function contains numerous losses, different weightings of those, which are normally passed as arguments to the loss function?
Possibly. I think you can create additional parameters in your Equinox module to act as the weights of the losses (with those weights masked/labelled appropriately). Then just optimize the final weighted loss.
Here gradient_mask should be a PyTree with the same structure as model, with a True/False leaf for every floating-point array indicating whether that array should be updated with a gradient, and with None for all leaves that aren't floating-point arrays.
Thanks @patrick-kidger for the further clarification. gradient_mask seems very elegant.
The scheduler code above results in a ConcretizationTypeError, because an abstract tracer value is encountered where a concrete value is expected (due to the use of Python if/else).
I think the gradient_mask approach is perfect for my particular use case.
Oh right, there is a ConcretizationTypeError. So then, gradient_mask is the right approach.
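For completeness, a jit-compatible variant of that schedule is possible by replacing the Python if/else with jnp.where, so the comparison happens on traced values (a sketch only, not tested against the rest of this setup):

import jax.numpy as jnp

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # even layers train on even epochs, odd layers on odd epochs
        target_parity = 0 if is_even else 1
        # jnp.where works on traced values, unlike Python if/else
        return jnp.where(current_epoch % 2 == target_parity, lr, 0.0)
    return lr_scheduler

That said, as concluded above, the gradient_mask approach sidesteps the schedule entirely.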