
Use lr setter callback instead of `attr_name` in `LearningRateFinder` and `Tuner`


Description & Motivation

I would like to change the Tuner and LearningRateFinder API so that it supports more customized models.

Description

Currently, the learning rate can only be accessed on the LightningModule through an attribute name (either as a direct attribute or within the hyperparameters hparams). The name is configured through the attr_name parameter.

However, I would like to replace the attr_name parameter with an lr_setter callback, to allow more advanced access and full freedom over where the learning rate lives inside the model.

Motivation

While the current implementation suits most use cases, it does not cover some advanced ones. Say I want to provide the learning rate through a partially instantiated optimizer; this works very well in a Hydra config setup.

For example, I find partially instantiated optimizers really helpful for tracking experiments. I often use something like:

import functools
from typing import Callable

from torch.optim import Adam, Optimizer
from lightning.pytorch import LightningModule


PartialOptimizer = Callable[..., Optimizer]

class LitModel(LightningModule):
    def __init__(self, optimizer: PartialOptimizer) -> None:
        super().__init__()
        self.save_hyperparameters()

        # Placeholder for the actual model
        self.model = Model(...)

    def configure_optimizers(self) -> Optimizer:
        # Instantiate the partial optimizer with this module's parameters
        optimizer = self.hparams.get("optimizer")
        return optimizer(params=self.parameters())


optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)

With this implementation it is not possible to use the LearningRateFinder callback or the Tuner, because the learning rate is not accessible through an attribute or the hyperparameters.

Pitch

I would like to change the attr_name parameter to lr_setter (or similar): a function that sets the learning rate on the model. This would make the Tuner and LearningRateFinder usable in more advanced cases while remaining compatible with attribute-defined learning rates.

The type of lr_setter could be Callable[[pl.LightningModule, float], None].
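As a rough sketch (not part of the current Lightning API; the LRSetter alias and default_lr_setter helper are hypothetical names), the parameter and a default setter that reproduces today's attribute-based behaviour could look like this:

from typing import Callable

import lightning.pytorch as pl
from lightning.pytorch.utilities.parsing import lightning_hasattr, lightning_setattr

# Hypothetical alias for the proposed parameter type
LRSetter = Callable[[pl.LightningModule, float], None]

def default_lr_setter(model: pl.LightningModule, lr: float) -> None:
    # Fallback mirroring today's behaviour: look for an `lr` or `learning_rate`
    # attribute (on the model or in its hparams) and set it.
    for name in ("lr", "learning_rate"):
        if lightning_hasattr(model, name):
            lightning_setattr(model, name, lr)
            return
    raise AttributeError("No `lr`/`learning_rate` attribute found; please provide an explicit lr_setter.")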

Use case: attr-defined

This is the case currently described in the docs:

# Using https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html#using-lightning-s-built-in-lr-finder
model = LitModel(learning_rate=0.001)

trainer = Trainer()
tuner = Tuner(trainer)


# Before
lr_finder = tuner.lr_finder(model, attr_name="learning_rate")  # Note: attr_name is optional here

# After
from lightning.pytorch.utilities.parsing import lightning_setattr

lr_finder = tuner.lr_find(model, lr_setter=lambda model, lr: lightning_setattr(model, "learning_rate", lr))

Use case: partial optimizer

This case is not compatible with the current API:

# Using the custom LitModel with partial optimizer instead of attr defined learning rate
optimizer = functools.partial(Adam, lr=0.001)
model = LitModel(optimizer)

trainer = Trainer()
tuner = Tuner(trainer)


# Before
lr_finder = tuner.lr_find(model, attr_name=???)  # Not possible to access the learning rate

# After
def partial_setattr(fn: functools.partial, key: str, value: float) -> None:
    # Rebuild the partial's state with the updated keyword argument
    *_, (f, args, kwargs, n) = fn.__reduce__()
    kwargs[key] = value
    fn.__setstate__((f, args, kwargs, n))

lr_finder = tuner.lr_find(model, lr_setter=lambda model, lr: partial_setattr(model.hparams["optimizer"], "lr", lr))

With the proposed lr_setter, the LearningRateFinder callback and the Tuner become usable even though the learning rate is not exposed as an attribute or a hyperparameter.
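As a side note (an observation, not part of the proposal): in CPython the keywords dict of a functools.partial can be mutated in place, so a setter for this use case could also be written without going through __reduce__/__setstate__. A minimal sketch:

import functools

from torch.optim import Adam

optimizer = functools.partial(Adam, lr=0.001)

def set_partial_lr(fn: functools.partial, lr: float) -> None:
    # `partial.keywords` is the dict merged into every call, so updating it
    # changes the lr the optimizer will eventually be created with.
    fn.keywords["lr"] = lr

set_partial_lr(optimizer, 0.01)
assert optimizer.keywords["lr"] == 0.01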

Alternatives

I have already implemented a version that achieves exactly this. The update is minimal, and the core changes are in the _lr_find function of the lightning/pytorch/tuner/lr_finder.py module.

There are only two places to update to make this work (a rough sketch follows the list):

  1. Remove or adapt the lines that automatically determine the attr_name, since the new feature uses an lr_setter function instead. We could keep similar convenience by auto-generating an lr_setter when none is provided: first check whether an lr or learning_rate attribute is defined and create the corresponding setter, then check whether a partial optimizer named optim or optimizer exists and adapt the lr_setter accordingly. Alternatively, we could require the user to always specify the setter.
  2. Call the lr_setter instead of the lightning_setattr function.
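
For illustration only (this assumes the current implementation applies the new learning rate via lightning_setattr; the helper below is a simplified stand-in, not the actual Lightning source), the modified call site could look roughly like this:

from typing import Callable, Optional

import lightning.pytorch as pl
from lightning.pytorch.utilities.parsing import lightning_setattr

LRSetter = Callable[[pl.LightningModule, float], None]

def _apply_lr(model: pl.LightningModule, new_lr: float, lr_setter: Optional[LRSetter] = None) -> None:
    # Sketch of both changes: (1) fall back to the attribute-based behaviour
    # when no setter is given, (2) otherwise call the user-provided setter
    # instead of lightning_setattr.
    if lr_setter is None:
        lightning_setattr(model, "learning_rate", new_lr)
    else:
        lr_setter(model, new_lr)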

This feature also requires renaming attr_name to lr_setter, to make it obvious that the parameter is a setter.

Additional context

This feature changes the API and is not backward compatible if the attr_name parameter is renamed. The capabilities remain the same, however, while offering more customization.

I already have a working implementation of this feature, compatible with the latest version of Lightning. If this is of interest, I can submit a PR. Thanks, I really love this library!

cc @borda

arthurdjn · Jun 05 '24 13:06