Support LR schedulers
Proposed changes
Added initial implementation of the following learning-rate schedulers (https://github.com/ml-explore/mlx/issues/333):
- StepLR
- ExponentialLR
- MultiStepLR
- LambdaLR
- PolynomialLR
- CosineAnnealingLR
- SequentialLR
They have been adapted from the PyTorch implementations, and the API and usage are similar. I'd like to discuss whether the goal is to attain feature parity with PyTorch in terms of schedulers supported, in which case I'm willing to implement the rest. Will update the documentation accordingly.
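(For illustration only, not the PR's API: StepLR's rule, mirroring PyTorch's semantics, can be sketched as a standalone function.)

def step_lr(base_lr, epoch, step_size, gamma=0.1):
    # multiply the base learning rate by `gamma` once every `step_size` epochs
    return base_lr * gamma ** (epoch // step_size)

assert step_lr(0.1, epoch=0, step_size=30) == 0.1
assert abs(step_lr(0.1, epoch=30, step_size=30) - 0.01) < 1e-12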
related to #333
Ok this may be a bit out there so let's gather some opinions. @awni and @jagrit06 I obviously would like yours as well.
I think PyTorch's lr schedulers are not particularly well used. For instance, calling step every epoch is not a given since some training runs only have one epoch or a couple. Moreover, this way we can only schedule learning rates but why not schedule weight decay annealing as well? Or any other scalar parameter of an optimizer or a loss?
What do you think about having a ScheduledScalar (or a better name) that implements the __mlx_array__ method so it can be seamlessly used in operations with arrays. And it can be given instead of a scalar for the learning rate.
For example:
learning_rate = CosineAnnealing(start=0.1, end=0.001)
weight_decay = LambdaValue(lambda progress: (1-progress) * 0.09 + 0.01)
extra_loss_weight = LambdaValue(lambda progress: progress * 0.001)
optimizer = SGD(learning_rate=learning_rate, weight_decay=weight_decay)
def loss_fn(model, x, y):
    y_hat = model(x)
    l1 = main_loss(y_hat, y)
    l2 = extra_loss(y_hat, y)
    return l1 + extra_loss_weight * l2
loss_and_grad = nn.value_and_grad(model, loss_fn)
for iteration, (x, y) in enumerate(dataset):
    [v.update(iteration / (total_iters - 1)) for v in [learning_rate, weight_decay, extra_loss_weight]]
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
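A minimal sketch of what one of these scheduled values could look like internally, assuming the __mlx_array__ conversion works as proposed (names and details hypothetical):

import math
import mlx.core as mx

class CosineAnnealing:
    def __init__(self, start, end):
        self.start, self.end = start, end
        self._value = mx.array(start)

    def update(self, progress):
        # cosine-anneal from `start` (progress=0) to `end` (progress=1)
        self._value = mx.array(
            self.end + 0.5 * (self.start - self.end) * (1 + math.cos(math.pi * progress))
        )

    def __mlx_array__(self):
        # lets the object be used wherever an array/scalar is expected
        return self._value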
WDYT? Feel free to say this is completely overkill and over-engineering and it will only confuse people further.
That reminds me of the way optax does schedules, which I actually find pretty nice (and am hoping we will follow something similar). Basically the learning_rate (and some other parameters) of the optimizer can also be a callable that outputs a scalar.
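For example, with optax (using its documented schedule API):

import optax

# the learning rate is a schedule: a callable mapping step -> scalar
schedule = optax.exponential_decay(
    init_value=0.1, transition_steps=1000, decay_rate=0.96
)
opt = optax.sgd(learning_rate=schedule)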
@angeloskath I don't see the need for __mlx_array__ though. Shouldn't this work if the callable returns a scalar or mx.array?
Edit: I think I understand now. You want to avoid needing to call the callable. I'm not sure there. I find it really nice to just be able to give some arbitrary schedule like lambda it: <my thing here>. But on the other hand we'll have to update parameters which can be callables in the internals of our optimizers etc. I think I prefer paying the latter cost for the former benefit rather than vice versa.
@angeloskath if I understand correctly, your idea for a ScheduledScalar is to have it behave as a scalar/array that internally manages its value based on a predefined schedule, i.e., it always evaluates to its current scheduled value at the current stage of the training process (e.g., the current step or epoch), without requiring external updates for each training step.
What I don't quite understand is how it should be aware of the training progress. Is the idea to maintain some global state? Sorry if I'm missing something obvious, I'm new here! 😅
@postmalloc exactly! Regarding the state, you'd still have to manually update it, see v.update(...) in my pseudo-example above. Also imagine you'd have a single object to call update on, so no need for this ugly loop.
But on the other hand we'll have to update parameters which can be callables in the internals of our optimizers etc.
@awni I am not sure I understand the above. Also the main reason is not only to avoid calling the callable but also to move the responsibility of state keeping one layer above, to the trainer. For optax (I may be wrong here as I am an optax noob), the step state is maintained by the optimizer.
In my proposal above the progress is maintained by the training loop. Also we could have any amount of scheduled values updated at once so we don't need to have a loop over them or keep many different step states.
Having said that, I think we need to optimize for ease of use and familiarity so I am totally not against using the lambda approach. In that approach we would have a _step value in the optimizer state that would be passed to the lr or weight decay callbacks right? Also updated every time optimizer.update is called?
Also the main reason is not only to avoid calling the callable but also to move the responsibility of state keeping one layer above, to the trainer.
That makes sense!
For optax (I may be wrong here as I am an optax noob), the step state is maintained by the optimizer.
Correct
I suppose you could still do your suggestion without the __mlx_array__ if you make the schedules callable and call them (in the optimizer)? But the __mlx_array__ does make it nice.
That said, a couple of things give me pause with having the schedulers so decoupled from the optimizer:
- The need to keep track of the schedulers and wire them through to your training loop (another thing to keep track of and update). The optimizer already has them.
- The API is unfamiliar: LambdaValue(lambda ...) versus just lambda ..., which is quite simple. If I want to write a new custom scheduler I have to use the LambdaValue class, which is another thing to be aware of.
@angeloskath are there some benefits of decoupling that I'm missing?
Hm, not many. Basically two:
- Keeping track of state in the optimizer is strictly less expressive. E.g., one has to update the lr, weight decay, etc. at every call to update. Step only goes forward, or it has to be manually reset to some valid value when loading.
- It only works for whatever the author of the optimizer (and only the optimizer) had in mind to make dynamic. For instance, does it work for the betas values? Or loss weighting?
The unfamiliar API is the big problem for me. I think my LambdaValue above was a bad first choice; if I had only shown CosineAnnealing it would feel more familiar.
So in the simplest case it would feel exactly like PyTorch.
lr = CosineAnnealing(start=1e-3, end=1e-5)
optimizer = Adam(learning_rate=lr)
for step, batch in enumerate(data):
    ...
    lr.update(step / total_steps)
It only works for whatever the author of the optimizer (and only the optimizer) had in mind to make dynamic. For instance, does it work for the betas values? Or loss weighting?
This is really nice. But also perhaps orthogonal, in the sense that we could have the optimizer parameters work with any instance that has an __mlx_array__ (which would be very cool).
Keeping track of state in the optimizer is strictly less expressive
Indeed, the main point is how much that flexibility is worth. I think I still lean towards keeping the training loop simple and having the user manage less state (the optax style) vs the PyTorch style with the explicit update/step. It feels like it covers most cases more simply. But perhaps coupled with __mlx_array__ we can still have some flexibility for those that want it?
They already work with any object that defines __mlx_array__. Basically the optimizer doesn't need to know anything about them. This would work anywhere we are using a scalar now.
import mlx.core as mx

class MyScalar:
    def __init__(self, val):
        self._val = mx.array(val)

    def __mlx_array__(self):
        return self._val
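If the conversion works as described, usage would be as simple as:

s = MyScalar(0.25)
x = mx.ones(4)
print(x * s)  # s is converted through __mlx_array__ like any other scalar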
Regardless, I will take a closer look at optax and think about how to best manage the step state. I might still come back advocating for more PyTorch-like explicit behavior 😅. I am not a big fan of JAX-like hidden behaviour. For instance, how would we even log the learning rate? Would we save the step as an array or a scalar, in the optimizer.state or outside, etc.?
Anyways, I hope I am not wasting everybody's time dealing with a relatively minor matter...
For instance, how would we even log the learning rate?
Calling optimizer.learning_rate should give the right learning rate for logging?
Would we save the step as an array or a scalar, in the optimizer.state or outside, etc.?
I am actually not against the idea of saving the step as a scalar array in the optimizer state. It makes it really easy to checkpoint training runs and just resume where you left off. This is something I find to be relatively simple with optax. In that sense the step really belongs in the optimizer state.
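For example, checkpointing could then be a one-liner (sketch, reusing `optimizer` from the examples above; tree_flatten is from mlx.utils):

from mlx.utils import tree_flatten
import mlx.core as mx

# the step array travels with the rest of the optimizer state
mx.savez("optimizer.npz", **dict(tree_flatten(optimizer.state)))
state = mx.load("optimizer.npz")  # reload later to resume where you left off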
Anyways, I hope I am not wasting everybody's time dealing with a relatively minor matter...
I'm enjoying the discussion!
I find the discussion very interesting too! I was playing around with the idea that @awni suggested. In the code, we could broadly make the following two changes: 1) introduce a ParameterScheduler class which allows one to define a schedule for any parameter as a callable. 2) modify the Optimizer class to allow it to track progress.
It looks something like this (omitted some code for brevity):
class ParameterScheduler:
    def __init__(self, schedule_fn):
        self.schedule_fn = schedule_fn

    def get_value(self, step):
        return self.schedule_fn(step)


class Optimizer:
    def __init__(self, schedulers=None):
        self.state = {'step': 0}
        self.schedulers = schedulers or {}

    def get_scheduled_value(self, param_name):
        step = self.state['step']
        scheduler = self.schedulers.get(param_name)
        return scheduler.get_value(step) if scheduler else None

    def update(self, model, gradients):
        self.state['step'] += 1
        model.update(self.apply_gradients(gradients, model))


class SGD(Optimizer):
    def __init__(self, learning_rate, momentum, weight_decay, schedulers=None):
        super().__init__(schedulers=schedulers)

    def apply_single(self, gradient, parameter, state):
        self.learning_rate = self.get_scheduled_value('learning_rate') or self.learning_rate
        self.weight_decay = self.get_scheduled_value('weight_decay') or self.weight_decay
        self.momentum = self.get_scheduled_value('momentum') or self.momentum
        ...
and in the training loop, we pass the callbacks to the optimizer:
total_steps = 1000
lr_scheduler = ParameterScheduler(lambda step: 0.1 - 0.09 * (step / total_steps))
wd_scheduler = ParameterScheduler(lambda step: 0.01 * (1 - step / total_steps))
optimizer = SGD(learning_rate=0.1, weight_decay=0.01,
                schedulers={'learning_rate': lr_scheduler, 'weight_decay': wd_scheduler})
@awni, is this along the lines of what you're thinking? I wonder if it might become a bit unwieldy to define a ParameterScheduler instead of a plain lambda function.
Thanks for the revision. This is a bit in between what I was suggesting and what @angeloskath was suggesting. I think if we go the route of explicit schedule classes (not callables) then they should expose __mlx_array__.
What I had in mind is more like this:
class Optimizer:
    def __init__(self):
        self.state = {'step': 0}

    def update(self, model, gradients):
        self.state['step'] += 1
        model.update(self.apply_gradients(gradients, model))

    @property
    def step(self):
        return self.state["step"]


class SGD(Optimizer):
    def __init__(self, learning_rate, momentum, weight_decay):
        super().__init__()
        # maybe provide a helper to reduce boilerplate here
        if isinstance(learning_rate, float):
            self.learning_rate_ = lambda _: learning_rate
        else:
            self.learning_rate_ = learning_rate

    def apply_single(self, gradient, parameter, state):
        parameter = parameter - self.learning_rate * gradient
        ...

    @property
    def learning_rate(self):
        return self.learning_rate_(self.step)
Use it like this:
optimizer = SGD(learning_rate=lambda step: step ** 0.99)
No more need to manage the scheduler once you pass it to the optimizer.
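The training loop then shrinks to (sketch, reusing loss_and_grad and dataset from the earlier example):

for x, y in dataset:
    loss, grads = loss_and_grad(model, x, y)
    optimizer.update(model, grads)  # the step counter advances internally
    mx.eval(model.parameters(), optimizer.state)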
@postmalloc are you still working on this? Just curious what the plan is as there have been some requests for schedulers :)
@awni I wasn't actually sure if a consensus had been reached between the approach you suggested and what @angeloskath suggested 😄. I do like the callback version you proposed, and I'd be happy to implement it if that's the direction we take.
I wasn't actually sure if a consensus had been reached between the approach you suggested and what @angeloskath suggested
Makes sense
I think it will be easier to weigh the pros/cons of the approach if we have a more concrete implementation (you don't have to do a lot of schedulers, just one or two + the scaffolding). Are you open to trying the version I suggested? If it ends up having some sticky issues we can revisit?
Hi @postmalloc checking in on this. Are you still working on the PR?
Hi @postmalloc are you planning to work on this PR at all? If not let's close it so we can let someone else tackle / work on schedulers.
Hi @awni, sorry for disappearing! I took your suggestion and expanded it a bit. It should allow one to pass callables for arbitrary hyperparameters:
# a helper function to wrap params as callable schedulers
def ensure_scheduler(parameter):
    if callable(parameter):
        return parameter
    else:
        return lambda _: parameter


class Optimizer:
    def __init__(self, schedulers=None):
        self.state = {'step': 0}
        self.schedulers = {k: ensure_scheduler(v) for k, v in (schedulers or {}).items()}

    def update(self, model, gradients):
        self.state['step'] += 1
        self.update_scheduled_params()
        model.update(self.apply_gradients(gradients, model))

    @property
    def step(self):
        return self.state["step"]

    def update_scheduled_params(self):
        for param, scheduler in self.schedulers.items():
            if hasattr(self, param):
                setattr(self, param, scheduler(self.step))


def exponential_decay_scheduler(initial_value, decay_rate, decay_steps):
    return lambda step: initial_value * (decay_rate ** (step / decay_steps))


def step_decay_scheduler(initial_value, drop_rate, step_size):
    return lambda step: initial_value * (drop_rate ** (step // step_size))
and we use it like this:
lr_scheduler = exponential_decay_scheduler(initial_value=0.1, decay_rate=0.96, decay_steps=1000)
momentum_scheduler = step_decay_scheduler(initial_value=0.9, drop_rate=0.5, step_size=5000)
optimizer = SGD(learning_rate=lr_scheduler, momentum=momentum_scheduler, weight_decay=0.01)
I can flesh it out and add tests if it looks okay to you.
The part you shared looks nice, it's pretty simple. One of our optimizers (AdaFactor) already has the step as state, so we'd need to refactor that into the base class.
Can you show how it would look in a training loop also?
CC @angeloskath
@postmalloc I actually like this design.
In order for that to work we need to change eval to work with arbitrary dictionaries that can contain non-array leaves. This is also true for Adafactor which currently would throw if someone evaluated the state. Also the state needs to be an OptimizerState rather than a dictionary.
The other option is to make the step a scalar array. This might come with some performance overhead though so we should test it before doing that. If there is no performance overhead at all, then the latter is quite nice as the optimizer state will also be savable with mx.save.
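A sketch of that second option (details assumed):

import mlx.core as mx

class Optimizer:
    def __init__(self):
        # step as a 0-d array: the state is then all arrays, so it can be
        # passed to mx.eval and saved with mx.savez as-is
        self.state = {"step": mx.array(0, mx.uint64)}

    def update(self, model, gradients):
        self.state["step"] = self.state["step"] + 1
        model.update(self.apply_gradients(gradients, model))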
The other option is to make the step a scalar array.
I like that option, makes saving easier + no changes to eval. Generally still probably ok if you forget to eval the optimizer state.
@postmalloc any updates on this?
This is a great idea and another thing I wouldn't have to roll my own version of. The only thing I would add is a request for SGDR (see cyclic-cosine-decay) or at least a framework to easily add it to this.
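For what it's worth, under the callable design an SGDR-style (warm restarts) schedule falls out naturally as just another schedule function (hypothetical sketch, not part of the PR; SGD as sketched above):

import math

def cyclic_cosine_decay(base_lr, cycle_steps, min_lr=0.0):
    def schedule(step):
        t = (step % cycle_steps) / cycle_steps  # position within the current cycle
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
    return schedule

optimizer = SGD(learning_rate=cyclic_cosine_decay(0.1, cycle_steps=1000))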
@angeloskath I finished this PR. Can you take a look?
Basically finished @postmalloc's suggested implementation with some modifications (mostly around making sure everything fits with compile).
The usage is pretty simple:
scheduler = optim.cosine_decay(...)
opt = optim.SGD(learning_rate=scheduler)
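Concretely (arguments here assumed: an initial value and the number of decay steps):

lr_schedule = optim.cosine_decay(1e-1, 1000)
opt = optim.SGD(learning_rate=lr_schedule)

for x, y in dataset:
    loss, grads = loss_and_grad(model, x, y)
    opt.update(model, grads)
    print(opt.learning_rate)  # current scheduled value, handy for logging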
Also adds an array step to the optimizer state, which needed to happen anyway to unbreak Adafactor with compile (CC @davidkoski).
Fixed a type annotation for Python 3.8. I wonder 1) if we should be running the tests on the oldest Python version we support (probably yes) and 2) if we should just support Python 3.9+.
Honestly I prefer to just support 3.9+. But I'm a little bit on the ruthless side when it comes to backward compatibility. I don't like how it adds complexity for very minor benefits.
But since you're the only 3.8 user I know, maybe you should decide 😛