# Use lr setter callback instead of `attr_name` in `LearningRateFinder` and `Tuner`
Description & Motivation
I would like to change the the Tuner
and LearningRateFinder
API so that it is possible to use more custom models.
#### Description

Currently, the learning rate can only be accessed on the `LightningModule` through an attribute name (either as a plain attribute or within the hyperparameters `hparams`). This is configured through the `attr_name` parameter.

I would like to replace the `attr_name` parameter with an `lr_setter` callback, to allow advanced access, customization, and full freedom over where the learning rate lives inside the model.
#### Motivation

While the current implementation suits most use cases, it does not fit some advanced usage. Say I want to provide the learning rate through a partially instantiated optimizer; this works very well in a Hydra / config setup. I find partially instantiated optimizers really helpful for tracking experiments, and I often use something like:
```python
import functools
from typing import Callable

from lightning.pytorch import LightningModule
from torch.optim import Adam, Optimizer

PartialOptimizer = Callable[..., Optimizer]


class LitModel(LightningModule):
    def __init__(self, optimizer: PartialOptimizer) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = Model(...)  # placeholder for the actual network

    def configure_optimizers(self) -> Optimizer:
        optimizer = self.hparams.get("optimizer")
        return optimizer(params=self.parameters())


optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)
```
With this implementation it is not possible to use the `LearningRateFinder` callback or the `Tuner`, because the learning rate is not accessible through an attribute or the hyperparameters.
### Pitch

I would like to change the `attr_name` parameter to `lr_setter` (or similar): a function that sets the learning rate on a given model. This would make the `Tuner` and `LearningRateFinder` usable in more advanced cases while remaining compatible with attr-defined learning rates.

The type of `lr_setter` could be `Callable[[pl.LightningModule, float], None]`.
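As a minimal sketch of the proposed contract (the `LrSetter` alias and `set_lr_attribute` helper below are hypothetical names, not part of Lightning):

```python
from typing import Callable

import lightning.pytorch as pl

# Hypothetical alias for the proposed parameter type.
LrSetter = Callable[[pl.LightningModule, float], None]


def set_lr_attribute(model: pl.LightningModule, lr: float) -> None:
    """A trivial lr_setter that reproduces today's attr_name behaviour."""
    model.learning_rate = lr
```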
#### Use case: attr-defined

This is the case currently described in the docs.
```python
# See https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html#using-lightning-s-built-in-lr-finder
from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.utilities.parsing import lightning_setattr

model = LitModel(learning_rate=0.001)  # the LitModel from the docs, which takes a learning_rate argument
trainer = Trainer()
tuner = Tuner(trainer)

# Before
lr_finder = tuner.lr_find(model, attr_name="learning_rate")  # note: attr_name is optional here

# After
lr_finder = tuner.lr_find(model, lr_setter=lambda model, lr: lightning_setattr(model, "learning_rate", lr))
```
#### Use case: partial optimizer

This case is not compatible with the current API.
```python
# Using the custom LitModel with a partial optimizer instead of an attr-defined learning rate
optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)
trainer = Trainer()
tuner = Tuner(trainer)

# Before
# lr_finder = tuner.lr_find(model, attr_name=???)  # not possible: the learning rate is not reachable by name

# After
def partial_setattr(fn: functools.partial, key: str, value: float) -> None:
    """Overwrite a keyword argument of a functools.partial in place."""
    *_, (f, args, kwargs, n) = fn.__reduce__()
    kwargs[key] = value
    fn.__setstate__((f, args, kwargs, n))

lr_finder = tuner.lr_find(model, lr_setter=lambda model, lr: partial_setattr(model.hparams["optimizer"], "lr", lr))
```
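For illustration, a quick sanity check of the `partial_setattr` helper above, independent of Lightning:

```python
import functools

from torch.optim import Adam

opt = functools.partial(Adam, lr=0.001)
partial_setattr(opt, "lr", 0.01)
assert opt.keywords["lr"] == 0.01  # the partial now carries the new learning rate
```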
### Alternatives

I already implemented a version that achieves exactly this. The update is minimal; the core changes are in the `_lr_find` function of the `lightning/pytorch/tuner/lr_finder.py` module.
There are only two places to update to make this work:

- Remove / adapt the lines that automatically find the `attr_name`, since this new feature will use an `lr_setter` function. We could add something similar that automatically generates an `lr_setter` if none is provided: first check whether an `lr` or `learning_rate` attr is defined and create the associated setter, then check whether there is a partial optimizer named `optim` or `optimizer` and adapt the `lr_setter` (see the sketch after this list). Maybe this is not necessary, and we instead force the user to specify the setter function.
- Call the `lr_setter` instead of using the `lightning_setattr` function.
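A minimal sketch of what that auto-generated fallback could look like (`default_lr_setter` is a hypothetical name; it reuses the `partial_setattr` helper from above):

```python
import functools
from typing import Callable, Optional

import lightning.pytorch as pl
from lightning.pytorch.utilities.parsing import lightning_hasattr, lightning_setattr

LrSetter = Callable[[pl.LightningModule, float], None]


def default_lr_setter(model: pl.LightningModule) -> Optional[LrSetter]:
    """Derive an lr_setter when the user does not provide one (hypothetical)."""
    # 1. Attr-defined learning rate, mirroring today's attr_name auto-detection.
    for name in ("lr", "learning_rate"):
        if lightning_hasattr(model, name):
            return lambda m, lr, name=name: lightning_setattr(m, name, lr)
    # 2. Partially instantiated optimizer stored in the hyperparameters.
    for name in ("optim", "optimizer"):
        if isinstance(model.hparams.get(name), functools.partial):
            return lambda m, lr, name=name: partial_setattr(m.hparams[name], "lr", lr)
    return None  # no setter found: force the user to pass one explicitly
```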
This feature also requires renaming `attr_name` to `lr_setter`, to make it obvious that the parameter is a setter.
### Additional context

This feature changes the API and is not backward compatible if `attr_name` is renamed. The capabilities remain the same, however, with more room for customization.
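If backward compatibility matters, a deprecation shim is conceivable; here is a sketch under the assumption that both parameters coexist for a release cycle (the helper name and warning text are illustrative, not an existing Lightning API):

```python
import warnings
from typing import Callable, Optional

import lightning.pytorch as pl
from lightning.pytorch.utilities.parsing import lightning_setattr


def _resolve_lr_setter(
    lr_setter: Optional[Callable[[pl.LightningModule, float], None]],
    attr_name: str = "",
) -> Optional[Callable[[pl.LightningModule, float], None]]:
    """Accept the old attr_name while steering users towards lr_setter."""
    if attr_name and lr_setter is None:
        warnings.warn("`attr_name` is deprecated, pass `lr_setter` instead.", DeprecationWarning)
        return lambda model, lr: lightning_setattr(model, attr_name, lr)
    return lr_setter
```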
I already have a working implementation of this feature, compatible with the latest version of lightning. If this is of interest, I can submit a PR. Thanks, really loving this library!
cc @borda