Using Frozen Parameters to Freeze Differing Parameters Based Upon Epoch

adam-hartshorne opened this issue on Feb 27, 2023 · 12 comments

I am trying to figure out whether it is possible to efficiently use something like the frozen-parameters approach described at https://docs.kidger.site/equinox/examples/frozen_layer/ to set up a training loop in which a different subset of parameters is frozen depending on the epoch.

e.g. For simplicity's sake, say you have an MLP and you want the weights of the odd and even layers to alternate between being updated each epoch (this isn't my real use case). As far as I can tell, if you use the filter_spec approach, the opt_state will have None values for the frozen parameters, so you then run into issues when you apply the opposite filter. Am I missing something here?
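
For concreteness, this is roughly the pattern from that example that I mean (a simplified sketch, where `model` is assumed to be an eqx.nn.MLP and `x`, `y` are some batch; the exact `where` targets are made up):

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# mark every leaf as frozen (False), then flip the leaves we want to train
filter_spec = jtu.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(
    lambda tree: (tree.layers[0].weight, tree.layers[0].bias),
    filter_spec,
    replace=(True, True),
)

@eqx.filter_grad
def loss(diff_model, static_model, x, y):
    model = eqx.combine(diff_model, static_model)
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

# gradients are only taken with respect to the "diff" part
diff_model, static_model = eqx.partition(model, filter_spec)
grads = loss(diff_model, static_model, x, y)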

adam-hartshorne avatar Feb 27 '23 01:02 adam-hartshorne

Hi,

I work on something similar. For your example, what I can think of is to create two optimizers, each of which transforms and updates its parameters differently via its own learning-rate schedule. Then we need to label the even- and odd-layer parameters. Finally, we can use optax.multi_transform (see docs) to select which optimizer is applied where.

Here is a small example

import equinox as eqx
import jax
import jax.tree_util
import optax


class Model(eqx.Module):
    
    even_layer: eqx.Module
    odd_layer: eqx.Module

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # return learning rate either 0 or the given learning rate based on epoch
        return lr  # a dummy return
    
    return lr_scheduler

lr = 1e-3
even_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=True, num_steps_per_epoch=1000))
odd_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=False, num_steps_per_epoch=1000))

# assume that your model is initialized
model = Model()

# first make all parameters with `even` labels
even_labels = jax.tree_util.tree_map(lambda _: "even", model)

def get_odd_layer(tree: Model):
    # return all the odd layers. might be more complicated in your model
    return tree.odd_layer
    
# replace some odd params with new label
final_labels = eqx.tree_at(
    where=get_odd_layer,                # find where is the odd layer
    pytree=even_labels,
    replace_fn= lambda _: "odd"         # and replace with new labels
)

# the final optimizer will select specific optimizers based on the label
optim = optax.multi_transform(
    transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,
    },
    param_labels=final_labels
)

Another solution is to use optax.chain with optax.masked (see this), where we can mask some of the parameters on/off; a sketch is below. But I find optax.multi_transform more convenient, since we can define many labels to apply different transformations.
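
Roughly, the masked version could look like this (untested sketch; even_mask and odd_mask are assumed to be boolean pytrees mirroring the model, True where that optimizer should act):

# each masked transform only touches the leaves whose mask is True and passes
# the other updates through unchanged, so two complementary masks give the
# same even/odd split as multi_transform
optim = optax.chain(
    optax.masked(even_layer_optim, even_mask),
    optax.masked(odd_layer_optim, odd_mask),
)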

anh-tong avatar Feb 27 '23 05:02 anh-tong

Thank you for the advice. I will give them a try.

Edit: How is "is_even" tested? Obviously the labels are set to odd or even, but how does the optimiser know whether "is_even" is true?

adam-hartshorne avatar Feb 27 '23 11:02 adam-hartshorne

Edit: Silly question, but with this multi_transform approach, how do you set up the training loop?

e.g. Normally with an Equinox "model" you use a setup something like this:

opt_init, opt_update = optax.adam(config.learning_rate)
opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))
...
loss, grads = loss_func(model, x, y)
updates, update_opt_state = opt_update(grads, opt_state)
update_model = eqx.apply_updates(model, updates)

With the following, I get a "TypeError: Shapes must be 1D sequences of concrete values of integer type, got (None, 64)." upon the call to opt_init(...):

opt_init, opt_update = optax.multi_transform(
    transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,
    },
    param_labels=final_labels
)

opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))

adam-hartshorne avatar Feb 27 '23 11:02 adam-hartshorne

Uhm, I don't know why it throws a TypeError here. Would you mind providing some more details (e.g., code, error trace)?

As far as I know, there is no further setup, and the training loop is the same as with normal Equinox. I elaborate more in the example code here:

import equinox as eqx
import optax
import jax
import jax.tree_util


class Model(eqx.Module):
    
    even_layer: eqx.Module
    odd_layer: eqx.Module

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # return learning rate either 0 or the given learning rate based on epoch 
        return lr  # a dummy return
        
    
    return lr_scheduler

lr = 1e-3
even_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=True, num_steps_per_epoch=1000))
odd_layer_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=False, num_steps_per_epoch=1000))

# assume that your model is initialized
model = Model(odd_layer=eqx.nn.Linear(3, 3, key=jax.random.PRNGKey(0)),
              even_layer=eqx.nn.Linear(3, 1, key=jax.random.PRNGKey(1)))

# first make all parameters with `even` labels
even_labels = jax.tree_util.tree_map(lambda _: "even", model)

def get_odd_layer(tree: Model):
    # return all the odd layers. might be more complicated in your model
    return tree.odd_layer
    
# replace some odd params with new label
final_labels = eqx.tree_at(
    where=get_odd_layer,                # find where is the odd layer
    pytree=even_labels,
    replace_fn= lambda _: "odd"         # and replace with new labels
)

# the final optimizer will select specific optimizers based on the label
optim_init, optim_update = optax.multi_transform(
    transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,
    },
    param_labels=final_labels
)

optim_init(eqx.filter(model, eqx.is_inexact_array))

Here, optim_init works fine.
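
The step itself then looks like any other Equinox/optax loop, roughly like this (a sketch, not run here; loss_fn, x, y are assumed):

opt_state = optim_init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit
def make_step(model, opt_state, x, y):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optim_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss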

anh-tong avatar Feb 27 '23 12:02 anh-tong

Edit: How is "is_even" tested? Obviously the labels are set to odd or even, but how does the optimiser know whether "is_even" is true?

You can print out final_labels to check what the pytree looks like. Unlike freezing a layer, where the leaves of the frozen nodes in the pytree are None, the leaves of final_labels are assigned either the label even or the label odd. I guess that multi_transform uses the keys of the dictionary to select the optimizers

transforms={
        "even": even_layer_optim,
        "odd": odd_layer_optim,}

and final_labels provides the per-parameter information that helps each optimizer find its corresponding parameters.
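
For example (roughly; the exact repr depends on the module's fields):

print(final_labels)
# something like:
# Model(
#   even_layer=Linear(weight='even', bias='even', ...),
#   odd_layer='odd'
# )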

anh-tong avatar Feb 27 '23 12:02 anh-tong

OK, so I believe I have found the issue, but I have no idea why it should be a problem. Any attempt to override the __call__ method results in various errors, depending on exactly what is being called from it. Define exactly the same code as a named method instead, and there are no errors.

Below is a (silly) example. I wonder if this is an Equinox problem or an optax problem?

import equinox as eqx
import jax
import jax.tree_util
import optax
from jax import vmap


class Model(eqx.Module):
    odd_mlp: eqx.nn.MLP
    even_mlp: eqx.nn.MLP

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.odd_mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=32, depth=3, key=jax.random.PRNGKey(0))
        self.even_mlp = eqx.nn.MLP(in_size=2, out_size=2, width_size=32, depth=3, key=jax.random.PRNGKey(1))

    def calc_mlp(self, x1):
        return vmap(self.odd_mlp)(x1)

    # INCLUDE THE CALL OVERRIDE, GET ERROR
    # def __call__(self, x1):
    #     return vmap(self.odd_mlp)(x1)

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        # return learning rate either 0 or the given learning rate based on epoch
        return lr  # a dummy return

    return lr_scheduler

lr = 1e-3
even_mlp_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=True, num_steps_per_epoch=1000))
odd_mlp_optim = optax.adam(learning_rate=create_lr_scheduler(lr, is_even=False, num_steps_per_epoch=1000))


# assume that your model is initialized
model = Model()

# first make all parameters with `even` labels
even_labels = jax.tree_util.tree_map(lambda _: "even", model)


def get_odd_mlp(tree: Model):
    # return all the odd layers. might be more complicated in your model
    return tree.odd_mlp


# replace some odd params with new label
final_labels = eqx.tree_at(
    where=get_odd_mlp,  # find where is the odd layer
    pytree=even_labels,
    replace_fn=lambda _: "odd"  # and replace with new labels
)

# the final optimizer will select specific optimizers based on the label
optim_init, optim_update = optax.multi_transform(
    transforms={
        "even": even_mlp_optim,
        "odd": odd_mlp_optim,
    },
    param_labels=final_labels
)

optim_init(eqx.filter(model, eqx.is_inexact_array))

adam-hartshorne avatar Feb 27 '23 23:02 adam-hartshorne

Oh, I'm sorry. I encountered this problem before but don't recall the details.

A solution here is to use a different method name instead of __call__. I don't know of another workaround off-hand (maybe optax.chain together with optax.masked can help, or passing param_labels as a function that decides the labels of the leaves).

optax.multi_transform takes param_labels of type Union[Any, Callable[[Any], Any]]. During validation, it checks whether the value we pass as final_labels is callable. If the Equinox module implements __call__, then final_labels is recognized as a callable, whereas what we intend is to pass a pytree of labels.
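
To illustrate:

# final_labels is an instance of Model, so once Model defines __call__ it is
# itself callable and optax treats it as a label *function* rather than a
# pytree of labels
callable(final_labels)  # True if Model defines __call__, False otherwise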

anh-tong avatar Feb 28 '23 02:02 anh-tong

Thanks for all your help, I now appear to have things running.

A couple of final questions. I am not sure I follow this part of the code: why would you ever return an LR of 0?

def lr_scheduler(num_steps):
    current_epoch = num_steps // num_steps_per_epoch
    # return learning rate either 0 or the given learning rate based on epoch
    return lr  # a dummy return

And is it possible to attach a different loss function to the odd/even optimisers? (Or, if there are numerous losses inside the loss function, to use different weightings of those losses, which are normally passed as arguments to the loss function?)

adam-hartshorne avatar Feb 28 '23 02:02 adam-hartshorne

Overriding __call__: this is a known problem with Optax; the reason is as @anh-tong describes.

If you want to keep the __call__ method, then a workaround is simply to wrap your model into a length-1 list; see #193 for an example of this issue in isolation.
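
A rough sketch of that workaround (untested here; both the label pytree and the filtered parameter pytree get wrapped, and the updates are unwrapped again before applying them):

optim = optax.multi_transform(
    transforms={"even": even_mlp_optim, "odd": odd_mlp_optim},
    param_labels=[final_labels],   # a list is not callable
)
opt_state = optim.init([eqx.filter(model, eqx.is_inexact_array)])
# ...
# updates, opt_state = optim.update([grads], opt_state)
# model = eqx.apply_updates(model, updates[0])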

As for the original problem of using different frozen parameters: probably the easiest approach is to pass in a mask indicating which parameters you want to update at each step, e.g.:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# `optim` below is any optax optimizer, e.g. optax.adam(learning_rate)

def loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y)**2)

@eqx.filter_jit
def make_step(model, opt_state, gradient_mask, x, y):
    grads = eqx.filter_grad(loss)(model, x, y)
    assert jtu.tree_structure(grads) == jtu.tree_structure(gradient_mask)
    grads = jtu.tree_map(lambda g, b: jnp.where(b, g, jnp.zeros_like(g)), grads, gradient_mask)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state

here gradient_mask should be a PyTree with the same structure as model, with a True/False leaf for every floating-point array indicating whether that array should be updated with a gradient, and with None for all leaves that aren't floating-point arrays.
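
For the odd/even example above, one way to build such a mask is sketched below (make_gradient_mask is a hypothetical helper, not from this thread; it reuses the eqx/jtu imports from the snippet above):

def make_gradient_mask(model, update_odd: bool):
    # False for every floating-point array (frozen by default), None elsewhere
    mask = jtu.tree_map(
        lambda x: False if eqx.is_inexact_array(x) else None, model
    )
    # flip the sub-network we want to train this epoch to True
    where = (lambda m: m.odd_mlp) if update_odd else (lambda m: m.even_mlp)
    return eqx.tree_at(
        where,
        mask,
        replace_fn=lambda sub: jtu.tree_map(lambda _: True, sub),
    )

# rebuild the mask each epoch and pass it into make_step:
# gradient_mask = make_gradient_mask(model, update_odd=(epoch % 2 == 1))
# model, opt_state = make_step(model, opt_state, gradient_mask, x, y)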

patrick-kidger avatar Feb 28 '23 03:02 patrick-kidger

Why would you ever return an LR of 0?

A quick implementation could be:

def create_lr_scheduler(lr: float, is_even: bool, num_steps_per_epoch: int):
    
    def lr_scheduler(num_steps):
        current_epoch = num_steps // num_steps_per_epoch
        if is_even:
            return lr if current_epoch % 2 == 0 else 0.
        else:
            return lr if current_epoch % 2 == 1 else 0.
        
    return lr_scheduler

Or, if there are numerous losses inside the loss function, to use different weightings of those losses, which are normally passed as arguments to the loss function?

Possibly. I think you can create additional parameters in your Equinox module to hold the weights of the losses (and mask/label those weights), as sketched below. Then just optimize the final weighted loss.
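
A minimal sketch of that idea (hypothetical fields and loss terms, not from this thread):

import equinox as eqx
import jax
import jax.numpy as jnp

class WeightedModel(eqx.Module):
    net: eqx.nn.MLP
    loss_weights: jax.Array  # e.g. shape (2,); can be labelled/masked like any other leaf

def loss(model, x, y):
    pred_y = jax.vmap(model.net)(x)
    mse = jnp.mean((pred_y - y) ** 2)
    mae = jnp.mean(jnp.abs(pred_y - y))
    w = model.loss_weights
    return w[0] * mse + w[1] * mae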

here gradient_mask should be a PyTree with the same structure as model, with a True/False leaf for every floating-point array indicating whether that array should be updated with a gradient, and with None for all leaves that aren't floating-point arrays.

Thanks @patrick-kidger for further clarification. gradient_mask seems very elegant.

anh-tong avatar Feb 28 '23 03:02 anh-tong

The code above results in a ConcretizationTypeError, because an abstract tracer value is encountered where a concrete value is required (due to the use of if / else).

I think the gradient_mask approach is perfect for my particular use case.

adam-hartshorne avatar Feb 28 '23 03:02 adam-hartshorne

Oh right, there is a ConcretizationTypeError. So gradient_mask is the right approach then.

anh-tong avatar Feb 28 '23 09:02 anh-tong