
ENH: Implement learning rate scheduling for variational inference

Open alvaropp opened this issue 2 years ago • 7 comments

Before

I’ve started using PyMC’s variational inference to fit my models. I have no prior experience with it, but I expect there are parallels with training neural networks, where the choice of optimiser and learning rate has a large impact on the quality and speed of training. A common technique there is a learning rate scheduler, which lowers the learning rate as training progresses: starting high gives fast initial progress, and reducing it in later epochs allows more precise convergence.
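A learning rate schedule is simply a function of the iteration number. Here is a minimal plain-Python illustration of two common fixed schedules, step decay and exponential decay. The function names are illustrative and are not PyMC or Keras API.

```python
import math


def step_decay(initial_lr: float, factor: float, step_size: int, iteration: int) -> float:
    """Multiply the learning rate by `factor` once every `step_size` iterations."""
    return initial_lr * factor ** (iteration // step_size)


def exponential_decay(initial_lr: float, decay_rate: float, iteration: int) -> float:
    """Decay the learning rate smoothly: lr = initial_lr * exp(-decay_rate * iteration)."""
    return initial_lr * math.exp(-decay_rate * iteration)


# Learning rate at a few iterations for each schedule
for it in (0, 999, 1000, 2500):
    print(it, step_decay(0.1, 0.1, 1000, it), exponential_decay(0.1, 1e-3, it))
```

Step decay changes the rate in discrete jumps. Exponential decay lowers it a little on every iteration.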

Currently, in PyMC, you have to choose a single learning rate for the whole fit. If it is too large the fit won't converge, and if it is too small the fit will be slow. You can fit once with a fairly large learning rate and then use the result as the starting point for a second round with a smaller learning rate, but this is not trivial and not very elegant.

After

Use a callback that reduces the optimiser's learning rate whenever the loss stagnates, and keep training:

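import pymc as pm
import pytensor

# Sketch of the proposed API: LearningRateScheduler is the proposed callback
# (prototyped as ReduceLROnPlateau further down); it does not exist in PyMC yet.
learning_rate = pytensor.shared(1e-1, "learning_rate")  # mutable learning rate
optimiser = pm.adam(learning_rate=learning_rate)

with model:  # `model` is assumed to be an existing pm.Model
    fit = pm.fit(
        obj_optimizer=optimiser,
        callbacks=[
            LearningRateScheduler(
                initial_learning_rate=learning_rate,
                factor=0.1,     # multiply the learning rate by 0.1 on each reduction
                patience=10,    # iterations without improvement before reducing
                min_lr=1e-4,    # never go below this learning rate
                cooldown=10,    # iterations to wait after a reduction
                verbose=False,
            ),
        ],
    )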
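This works because the optimiser reads the learning rate from mutable state on every step. A callback that changes that state therefore affects all subsequent updates without rebuilding anything. The toy sketch below shows the mechanism without PyMC:

- `SharedValue` is a hypothetical stand-in exposing only the `get_value`/`set_value` pair that pytensor shared variables have.
- `toy_fit` is plain gradient descent on x², not PyMC's ADVI loop.
- Callbacks are called as `callback(approx, loss_hist, i)`, mirroring the PyMC convention.

```python
class SharedValue:
    """Hypothetical stand-in for a pytensor shared variable (get_value/set_value only)."""

    def __init__(self, value: float):
        self._value = value

    def get_value(self) -> float:
        return self._value

    def set_value(self, value: float) -> None:
        self._value = value


def toy_fit(x0, grad, loss, learning_rate, n, callbacks=()):
    """Toy gradient descent: re-reads the learning rate on every step and calls
    each callback as callback(approx, loss_hist, i) after every iteration."""
    x = x0
    loss_hist = []
    for i in range(n):
        x = x - learning_rate.get_value() * grad(x)  # learning rate read fresh each step
        loss_hist.append(loss(x))
        for cb in callbacks:
            cb(x, loss_hist, i)  # x plays the role of the approximation
    return x, loss_hist


lr = SharedValue(0.1)
lr_trace = []


def halve_every_20(approx, loss_hist, i):
    # Example schedule callback: halve the learning rate every 20 iterations
    if (i + 1) % 20 == 0:
        lr.set_value(lr.get_value() / 2)
    lr_trace.append(lr.get_value())


x_final, hist = toy_fit(
    5.0, grad=lambda x: 2 * x, loss=lambda x: x * x,
    learning_rate=lr, n=60, callbacks=[halve_every_20],
)
```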

Context for the issue:

No response

alvaropp, Nov 14 '23

:tada: Welcome to PyMC! :tada: We're really excited to have your input into the project! :sparkling_heart:
If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

welcome[bot], Nov 14 '23

I've been discussing this with @jessegrabowski in the PyMC forum: https://discourse.pymc.io/t/learning-rate-scheduling-for-variational-inference/13075/1

alvaropp, Nov 14 '23

I suggest we follow Keras' ReduceLROnPlateau and implement this learning rate scheduling as a callback as follows:

import numpy as np
import pytensor  # needed for the `pytensor.shared` annotation below
from pymc.variational.callbacks import Callback

class ReduceLROnPlateau(Callback):
    """Reduce learning rate when the loss has stopped improving.

    This is inspired by Keras' homonymous callback:
    https://github.com/keras-team/keras/blob/v2.14.0/keras/callbacks.py

    Parameters
    ----------
    initial_learning_rate: pytensor.shared
        shared variable containing the learning rate; it is updated in place when the rate is reduced
    factor: float
        factor by which the learning rate will be reduced: `new_lr = lr * factor`
    patience: int
        number of iterations with no improvement after which the learning rate will be reduced
    min_lr: float
        lower bound on the learning rate
    cooldown: int
        number of iterations to wait before resuming normal operation after the learning rate has been reduced
    verbose: bool
        False: quiet, True: print a message whenever the learning rate is reduced
    """

    def __init__(
        self,
        initial_learning_rate: pytensor.shared,
        factor=0.1,
        patience=10,
        min_lr=1e-6,
        cooldown=0,
        verbose=True,
    ):
        self.learning_rate = initial_learning_rate
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.verbose = verbose
        self.cooldown_counter = 0
        self.wait = 0
        self.best = float("inf")
        self.old_lr = None

    def __call__(self, approx, loss_hist, i):
        current = loss_hist[-1]

        if np.isinf(current):
            return

        if self.in_cooldown():
            self.cooldown_counter -= 1
            self.wait = 0
            return

        if current < self.best:
            self.best = current
            self.wait = 0
        elif not np.isinf(self.best):
            self.wait += 1
            if self.wait >= self.patience:
                self.reduce_lr()
                self.cooldown_counter = self.cooldown
                self.wait = 0

    def reduce_lr(self):
        old_lr = float(self.learning_rate.get_value())
        if old_lr > self.min_lr:
            new_lr = max(old_lr * self.factor, self.min_lr)
            self.learning_rate.set_value(new_lr)
            if self.verbose:
                print(
                    f"Reduced learning rate to {new_lr} after {self.patience} iterations without improvement."
                )

    def in_cooldown(self):
        return self.cooldown_counter > 0
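You can sanity-check the schedule without fitting a model by replaying the same plateau logic on a synthetic loss history. The sketch below is a dependency-free re-implementation of the `__call__`/`reduce_lr` logic above. It is for illustration only and is not the callback itself; `plateau_lr_trace` is a hypothetical helper name. It returns the learning rate after each call.

```python
import math


def plateau_lr_trace(losses, initial_lr, factor=0.1, patience=10, min_lr=1e-6, cooldown=0):
    """Replay the ReduceLROnPlateau logic over `losses`; return the LR after each call."""
    lr = initial_lr
    best = float("inf")
    wait = 0
    cooldown_counter = 0
    trace = []
    for current in losses:
        if math.isinf(current):
            pass  # infinite losses are ignored, as in the callback
        elif cooldown_counter > 0:
            cooldown_counter -= 1  # still cooling down after a reduction
            wait = 0
        elif current < best:
            best = current  # improvement: reset the patience counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                if lr > min_lr:
                    lr = max(lr * factor, min_lr)
                cooldown_counter = cooldown
                wait = 0
        trace.append(lr)
    return trace


# A completely flat loss: reductions happen every cooldown + patience = 20 calls
trace = plateau_lr_trace([1.0] * 100, initial_lr=0.1, factor=0.1,
                         patience=10, min_lr=1e-4, cooldown=10)
```

With `patience=10` and `cooldown=10`, consecutive reductions are at least 20 iterations apart. `min_lr` caps the learning rate at 1e-4 after three or four reductions.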

This callback can then be nicely combined with CheckParametersConvergence() for early stopping:

import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor
from pymc.variational.callbacks import CheckParametersConvergence, Tracker

# ReduceLROnPlateau is the class defined in the previous snippet.

# Toy data
length = 720
rng = np.random.default_rng(1337)
x = np.linspace(1e-2, 1, num=length)
true_regression_line = 5 * x + 4
y = true_regression_line + rng.normal(0, 1, size=length)
y[rng.integers(0, length, size=10)] += rng.normal(0, 4, size=10)
y = (y - y.mean()) / y.std()

# Model with early stopping and learning rate scheduling
with pm.Model() as lr_model:
    sigma = pm.HalfCauchy("sigma", beta=10)
    intercept = pm.Normal("intercept", 0, sigma=20)
    slope = pm.Normal("slope", 0, sigma=20)
    likelihood = pm.Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)

    learning_rate = pytensor.shared(1e-1, "learning_rate")
    optimiser = pm.adam(learning_rate=learning_rate)
    tracker = Tracker(lr=lambda: optimiser.keywords["learning_rate"].get_value())

    fit = pm.fit(
        obj_optimizer=optimiser,
        callbacks=[
            tracker,
            CheckParametersConvergence(),
            ReduceLROnPlateau(
                initial_learning_rate=learning_rate,
                factor=0.1,
                patience=10,
                min_lr=1e-4,
                cooldown=10,
                verbose=True,
            ),
        ],
        random_seed=1337,
    )

_, axs = plt.subplots(1, 2, figsize=(8, 4))
axs[0].plot(np.arange(len(fit.hist)), fit.hist)
axs[0].set_yscale("log")
axs[0].set_title("Fit history")
axs[1].plot(tracker.hist["lr"], ".-")
axs[1].set_yscale("log")
axs[1].set_title("LR")
plt.tight_layout()

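[Figure: output of the script above, with two panels: "Fit history" (loss per iteration, log scale) and "LR" (learning rate per iteration, log scale)]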
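Early stopping with `CheckParametersConvergence` is based, as far as I understand it, on the relative change in the flattened parameters between checks. It stops the fit when a norm of that change (the infinity norm by default) falls below a tolerance. The plateau callback plausibly helps here because smaller learning rates mean smaller parameter updates. The plain-Python sketch below approximates that criterion; the helper names and the `eps` guard are assumptions, not PyMC's code.

```python
def relative_change(current, previous, eps=1e-6):
    """Element-wise relative change between two parameter vectors; eps avoids division by zero."""
    return [(abs(c - p) + eps) / (abs(p) + eps) for c, p in zip(current, previous)]


def has_converged(current, previous, tolerance=1e-3):
    """Converged when the infinity norm (largest element) of the relative change is below tolerance."""
    return max(relative_change(current, previous)) < tolerance
```

Parameters that sit exactly at zero have a relative change of 1 under this measure, so they never count as converged.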

alvaropp, Nov 14 '23

This is a good strategy, thanks for taking the time to work on this. Feel free to open a pull request and move the discussion there.

fonnesbeck, Nov 14 '23

+1 for a full PR.

My only concerns with the current approach are:

  1. I don't think users should have to know to make the learning rate a shared variable to use the feature (this needs to be checked and handled automatically somehow)
  2. Is this approach robust to arbitrary combinations of learning rate schedulers? If I want, for example, a cosine schedule combined with a linear decrease, can this be done using callbacks only?

jessegrabowski, Nov 14 '23
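For the second concern: schedules that are pure functions of the iteration number compose easily with a single callback. You multiply the schedules together and let one callback overwrite the learning rate on every iteration. In the sketch below:

- All names are hypothetical.
- The cosine schedule here restarts every `period` iterations.
- `set_lr` would be something like `learning_rate.set_value` for a pytensor shared variable.

One caveat: because the callback overwrites the value every iteration, combining it with a stateful callback such as ReduceLROnPlateau on the same shared variable would conflict.

```python
import math
from typing import Callable

Schedule = Callable[[int], float]  # iteration -> multiplier on the base learning rate


def cosine_annealing(period: int) -> Schedule:
    """Multiplier falling from 1 towards 0 along a half cosine, restarting every `period` iterations."""
    return lambda i: 0.5 * (1.0 + math.cos(math.pi * (i % period) / period))


def linear_decay(total: int, final: float = 0.0) -> Schedule:
    """Multiplier decreasing linearly from 1 to `final` over `total` iterations, then constant."""
    return lambda i: 1.0 + (final - 1.0) * min(i / total, 1.0)


def product(*schedules: Schedule) -> Schedule:
    """Combine schedules by multiplying their multipliers."""
    def combined(i: int) -> float:
        out = 1.0
        for s in schedules:
            out *= s(i)
        return out
    return combined


def make_schedule_callback(set_lr: Callable[[float], None], base_lr: float, schedule: Schedule):
    """Callback with the (approx, loss_hist, i) call convention that sets the LR every iteration."""
    def callback(approx, loss_hist, i):
        set_lr(base_lr * schedule(i))
    return callback


# Cosine schedule damped by a linear decay; values.append stands in for set_value
sched = product(cosine_annealing(100), linear_decay(1000))
values = []
cb = make_schedule_callback(values.append, 0.1, sched)
for i in range(1000):
    cb(None, [], i)
```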

Can most of this be abstracted to something like pm.rate_scheduler which is then passed to pm.fit?

fonnesbeck, Dec 01 '23

@fonnesbeck we had some back-and-forth on Discourse about using callbacks (what was implemented in the PR) vs. a wrapper around an optimizer (that is then passed to pm.fit).

jessegrabowski, Dec 01 '23
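The optimizer-wrapper alternative can be illustrated without PyMC. The wrapper owns the iteration counter and derives the learning rate from a schedule, so the user passes a plain float and never creates a shared variable, which would also address the first concern above. `ScheduledSGD` is a hypothetical name, and the update is plain gradient descent rather than PyMC's Adam. A real PyMC version would presumably have to build the schedule into the symbolic update graph instead.

```python
from typing import Callable, List


class ScheduledSGD:
    """Hypothetical optimizer wrapper: keeps its own step counter and computes the
    learning rate as base_lr * schedule(step), so no mutable LR is exposed to users."""

    def __init__(self, base_lr: float, schedule: Callable[[int], float]):
        self.base_lr = base_lr
        self.schedule = schedule
        self.step_count = 0

    @property
    def learning_rate(self) -> float:
        return self.base_lr * self.schedule(self.step_count)

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        lr = self.learning_rate  # learning rate for this step, from the schedule
        self.step_count += 1
        return [p - lr * g for p, g in zip(params, grads)]


# Halve the learning rate every 10 steps while minimising x**2 (gradient 2x)
opt = ScheduledSGD(0.1, lambda i: 0.5 ** (i // 10))
params = [4.0]
for _ in range(30):
    params = opt.step(params, [2.0 * p for p in params])
```

Compared with the callback approach, the schedule lives inside the optimizer, so `pm.fit` would need no extra callback. The trade-off is that loss-dependent schedules such as ReduceLROnPlateau would need the loss passed into the wrapper.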

Duplicate of #6954

fonnesbeck, Jun 14 '24