Investigate if lr_scheduler from segmentation can use PyTorch's schedulers

Open • fmassa opened this issue 2 years ago • 8 comments

When the LR scheduler in the segmentation reference scripts was first implemented back in 2019, it couldn't be expressed with native PyTorch schedulers, so we had to resort to LambdaLR: https://github.com/pytorch/vision/blob/9275cc61fb3c26ce15ced0199ad8b7540d48676c/references/segmentation/train.py#L136-L138

This may now be available natively in PyTorch, in which case the code can be simplified.

cc @datumbox

fmassa, Sep 17 '21

Has a polynomial learning rate scheduler been implemented yet? This issue is still open: https://github.com/pytorch/pytorch/issues/2854

cc: @fmassa @datumbox

avijit9, Sep 29 '21

I also had a look when Francisco raised the ticket but, TBH, couldn't find anything compatible.

datumbox, Sep 29 '21

It might not be implemented yet. I think we should check whether this type of scheduler has been used in more papers since then; that could justify adding it to PyTorch.

fmassa, Sep 29 '21

I'm removing the "good first issue" tag because I don't think such a scheduler exists in Core, and resolving this would need a more thorough investigation. Coordinating with Core to add it may be worthwhile, but that's not a great Bootcamp task.

datumbox, Nov 02 '21

Hi guys,

I'm working on this issue as reported here. However, I need some more information about the expected behavior of the scheduler.

So far, I have considered the following resources:

  1. The current implementation in torchvision's train script.
  2. https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/PolynomialDecay
  3. https://github.com/cmpark0126/pytorch-polynomial-lr-decay

Let's go through them one by one.

To make the comparison fair, I'll fix the following parameters:

import torch

lr = 1e-3
end_learning_rate = 1e-4
max_decay_step = 4
power = 1.0
data_loader = range(0, 5)

torchvision's current approach (LambdaLR):

optimizer = torch.optim.SGD([torch.zeros(1)], lr=lr)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: (1 - step / len(data_loader)) ** power,
)

for i in data_loader:
    lr_scheduler.step(i)  # explicit step index (the deprecated step(epoch) form)
    print(i, optimizer.param_groups[0]["lr"])

0 0.001
1 0.0008
2 0.0006
3 0.0004
4 0.00019999999999999996

TensorFlow's PolynomialDecay:

import tensorflow as tf

poly = tf.keras.optimizers.schedules.PolynomialDecay(
    lr,
    max_decay_step,
    end_learning_rate=end_learning_rate,
    power=power,
    cycle=False,
    name=None,
)

for i in range(0, 5):
    print(i, poly(i))

0 tf.Tensor(0.001, shape=(), dtype=float32)
1 tf.Tensor(0.00077499996, shape=(), dtype=float32)
2 tf.Tensor(0.00055, shape=(), dtype=float32)
3 tf.Tensor(0.000325, shape=(), dtype=float32)
4 tf.Tensor(1e-04, shape=(), dtype=float32)
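
PolynomialLRDecay from pytorch-polynomial-lr-decay (the repository linked above):

optim = torch.optim.SGD([torch.zeros(10)], lr=lr)
scheduler = PolynomialLRDecay(
    optim,
    max_decay_steps=max_decay_step,
    end_learning_rate=end_learning_rate,
    power=power,
)

# Calling step() without an argument and printing after each step
for i in range(0, 5):
    scheduler.step()
    print(i, optim.param_groups[0]["lr"])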

0 0.00055
1 0.000325
2 0.0001
3 0.0001
4 0.0001

# Passing the step index explicitly
for i in range(0, 5):
    scheduler.step(i)
    print(i, optim.param_groups[0]["lr"])

0 0.0007750000000000001
1 0.0007750000000000001
2 0.00055
3 0.000325
4 0.0001
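The first two result sets differ because they follow different conventions, not because either is buggy. A pure-Python sketch of both formulas, with no framework involved: the LambdaLR factor decays to 0 over `len(data_loader)` steps, while the TensorFlow formula (as documented for `PolynomialDecay` with `cycle=False`) clamps the step at `decay_steps` and interpolates down to `end_learning_rate`. The function names `lambda_lr_poly` and `tf_polynomial_decay` are hypothetical helpers for illustration.

```python
def lambda_lr_poly(step, base_lr, total_steps, power=1.0):
    """LambdaLR-style schedule: base_lr times (1 - step/total_steps)**power, no floor."""
    return base_lr * (1 - step / total_steps) ** power


def tf_polynomial_decay(step, initial_lr, decay_steps, end_lr, power=1.0):
    """PolynomialDecay formula with cycle=False: clamp the step, then
    interpolate from initial_lr down to end_lr."""
    step = min(step, decay_steps)
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


lr = 1e-3
end_learning_rate = 1e-4
max_decay_step = 4
steps = range(0, 5)

lambda_values = [lambda_lr_poly(s, lr, len(steps)) for s in steps]
tf_values = [tf_polynomial_decay(s, lr, max_decay_step, end_learning_rate) for s in steps]

for s, a, b in zip(steps, lambda_values, tf_values):
    print(s, a, b)
```

Both lists should match the printed values above up to float32/float64 rounding. The LambdaLR version has no floor and reaches 0 only one step past the end of the loop, whereas the TensorFlow one reaches `end_learning_rate` at `decay_steps` and stays there.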

Open issues:

  • I have noticed that scheduler.step(epoch) is deprecated, or about to be. Right now, it's handled by the _get_closed_form_lr method, if one is available. Should we continue to support it? Moreover, scheduler.step(epoch) and scheduler.step() should behave the same, right?
  • Looking at the implementation of _LRScheduler, it seems that a step is performed just by instantiating the scheduler. This means we effectively "skip" one learning rate decay value. Is this what we want?
  • Considering the above example, what are the expected/correct LR values?

@datumbox

federicopozzi33, Jul 29 '22

@federicopozzi33 These are all very good questions. Unfortunately, I wasn't too familiar with the Scheduler API, so to answer them I had to implement it and experiment.

Here is the proposed implementation:

import warnings

import torch
from torch.optim.lr_scheduler import _LRScheduler


class PolynomialLR(_LRScheduler):
    def __init__(self, optimizer, total_iters=5, min_lr=0.0, power=1.0, last_epoch=-1, verbose=False):
        self.total_iters = total_iters

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.power = power
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
            )

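        # Step 0 (reached by the initial step() in the constructor): keep the initial LR.
        if self.last_epoch == 0:
            return [group["lr"] for group in self.optimizer.param_groups]

        # Past the end of the schedule: hold the LR at its floor.
        if self.last_epoch > self.total_iters:
            return [self.min_lrs[i] for i in range(len(self.optimizer.param_groups))]

        # The optimizer still holds the previous step's LR, so divide out the
        # previous decay factor and apply the current one (chainable form).
        return [
            self.min_lrs[i]
            + ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters))
            ** self.power
            * (group["lr"] - self.min_lrs[i])
            for i, group in enumerate(self.optimizer.param_groups)
        ]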

    def _get_closed_form_lr(self):
        return [
            (
                self.min_lrs[i]
                + (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power
                * (base_lr - self.min_lrs[i])
            )
            for i, base_lr in enumerate(self.base_lrs)
        ]


# Test it
lr = 0.001
total_iters = 5
power = 1.0

scheduler = PolynomialLR(
    torch.optim.SGD([torch.zeros(1)], lr=lr),
    total_iters=total_iters,
    min_lr=0.0,  # Using 0 because the Lambda doesn't support this option
    power=power,
)
scheduler2 = torch.optim.lr_scheduler.LambdaLR(
    torch.optim.SGD([torch.zeros(1)], lr=lr), lambda step: (1 - step / total_iters) ** power
)


for i in range(0, total_iters):
    print(i, scheduler.optimizer.param_groups[0]["lr"], scheduler2.optimizer.param_groups[0]["lr"])
    scheduler.step()
    scheduler2.step()
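The class above computes the LR in two ways: the chainable `get_lr()` used by `step()`, and `_get_closed_form_lr()` used when an explicit epoch is passed. A pure-Python sketch that mirrors both code paths, with no torch objects, to check that they produce the same sequence, including a non-zero `min_lr`, `power != 1` and steps beyond `total_iters`. The helper names and the test values are hypothetical choices for illustration only.

```python
def chainable_lrs(base_lr, min_lr, total_iters, power, n_steps):
    """Mirror get_lr(): each step rescales the previous LR in place."""
    lr = base_lr
    history = [lr]  # step 0 keeps the initial LR
    for t in range(1, n_steps):
        if t > total_iters:
            lr = min_lr
        else:
            ratio = (1.0 - t / total_iters) / (1.0 - (t - 1) / total_iters)
            lr = min_lr + ratio ** power * (lr - min_lr)
        history.append(lr)
    return history


def closed_form_lrs(base_lr, min_lr, total_iters, power, n_steps):
    """Mirror _get_closed_form_lr(): compute each LR directly from base_lr."""
    return [
        min_lr + (1.0 - min(total_iters, t) / total_iters) ** power * (base_lr - min_lr)
        for t in range(n_steps)
    ]


# Arbitrary test values; n_steps > total_iters also exercises the clamping branch.
chain = chainable_lrs(1e-3, 1e-4, total_iters=5, power=2.0, n_steps=8)
closed = closed_form_lrs(1e-3, 1e-4, total_iters=5, power=2.0, n_steps=8)

for t, (a, b) in enumerate(zip(chain, closed)):
    print(t, a, b)
```

If the two columns agree up to floating-point rounding, the step()-based and step(epoch)-based paths stay consistent, which is what the first open question asks for.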

Here are some answers to your questions:

  1. Yes indeed. The safest approach is to inherit step from _LRScheduler and get around the problem altogether.
  2. It seemed that the first iteration was skipped because you were printing AFTER the step. The step() is called at the very end and is the one that updates the LR value. The problem was previously masked because you explicitly passed the epoch value.
  3. The expected LR values are from LambdaLR. The above implementation produces the expected output:
0 0.001 0.001
1 0.0008 0.0008
2 0.0006 0.0006
3 0.0004 0.0004
4 0.00019999999999999996 0.00019999999999999996
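A small sketch of the loop order behind answer 2, using `LambdaLR` with the same parameters as the test above: read the LR the optimizer will use for this iteration, step the optimizer, then step the scheduler. The forward/backward pass is omitted, and the single zero-valued parameter is only a placeholder.

```python
import torch

base_lr = 1e-3
total_iters = 5

param = torch.nn.Parameter(torch.zeros(1))  # placeholder parameter
optimizer = torch.optim.SGD([param], lr=base_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: (1 - step / total_iters) ** 1.0
)

# The constructor already ran one internal step: last_epoch moved from -1 to 0,
# and the LR for step 0 (factor 1.0) is in place before training starts.
assert scheduler.last_epoch == 0

lrs = []
for i in range(total_iters):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used in this iteration
    optimizer.step()   # forward/backward omitted in this sketch
    scheduler.step()   # prepare the LR for the next iteration

print(lrs)
```

Reading the LR before `scheduler.step()` shows every value of the schedule starting at the base LR, so the initial step in the constructor does not actually skip a decay value.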

Though I think we can use the above implementation as-is, to be able to contribute it to PyTorch core we need tests, docs and a few more bells and whistles. I believe the PR https://github.com/pytorch/pytorch/pull/60836 is a good example of what needs to be done. If you are up for it, you can start a PR and I can help you get it merged. Alternatively, I can finish it off and find you a different primitive. Let me know what you prefer.

datumbox, Aug 02 '22

Hi @datumbox,

thank you for your help.

I have some doubts about the meaning of min_lr. If I understand correctly, it has two meanings:

  1. It is the lower bound on the learning rate, i.e. LR will never be lower than min_lr.
  2. LR is set to min_lr if last_epoch > total_iters.

I didn't find any references for some parts of the formula you used for the decayed LR. Although the values seem correct to me, I have some doubts about the part:

((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters))

Could you explain it in more detail?


> It seemed that the first iteration was skipped because you were printing AFTER the step. The step() is called at the very end and is the one that updates the LR value. The problem was previously masked when you passed explicitly the epoch value.

Ok, I get what you mean, but I was referring to this.


> Though I think we can use the above implementation as-is, to be able to contribute it to PyTorch core we need tests, docs and a few more bells and whistles. I believe the PR https://github.com/pytorch/pytorch/pull/60836 is a good example of what needs to be done. If you are up for it, you can start a PR and I can help you get it merged. Alternatively, I can finish it off and find you a different primitive. Let me know what you prefer.

Yeah, I'm putting the pieces together. I will open a PR soon.

federicopozzi33, Aug 03 '22

Correct, min_lr is the minimum permitted value for the LR. TBH, I'm not 100% sure we have to support this. Let's see what the Core team says and whether there are any weird interactions we should keep in mind.

> Although it seems correct to me, I have some doubts about the part:

The Scheduler API is a bit weird. The updates in get_lr() are applied in place, so you need to undo the previous epoch's update and apply the new one.
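To make the "undo and reapply" step concrete, here is a short derivation of why that ratio reproduces the closed form. The notation is introduced only for this sketch: $\eta_t$ is the LR after $t$ scheduler steps, $\eta_0$ the initial LR (`base_lr`), $m$ is `min_lr`, $T$ is `total_iters` and $p$ is `power`. Since `get_lr()` only sees the LR left in the optimizer by the previous step, it divides out the previous factor $(1 - (t-1)/T)^p$ and multiplies in the current one $(1 - t/T)^p$; unrolling the recursion makes the product telescope.

```latex
\begin{align*}
\eta_t &= m + \left(\frac{1 - t/T}{1 - (t-1)/T}\right)^{p} (\eta_{t-1} - m), \qquad 1 \le t \le T \\
\eta_t - m &= (\eta_0 - m) \prod_{s=1}^{t} \left(\frac{1 - s/T}{1 - (s-1)/T}\right)^{p}
            = (\eta_0 - m) \left(1 - \frac{t}{T}\right)^{p}
\end{align*}
```

The last line is the expression in `_get_closed_form_lr()`, where `min(self.total_iters, self.last_epoch)` caps $t$ at $T$, so that for $t \ge T$ the LR stays at $m$.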

> Yeah, I'm putting the pieces together. I will open a PR soon.

Sounds good, make sure you tag me on the PR.

datumbox, Aug 03 '22

The scheduler has been implemented (see https://github.com/pytorch/pytorch/pull/82769).

All that remains is to update the segmentation training script to use the new scheduler once a new version of PyTorch is released.
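A minimal sketch of what that swap could look like, checking that `torch.optim.lr_scheduler.PolynomialLR` produces the same per-iteration LRs as a LambdaLR polynomial factor. It assumes a PyTorch version that ships `PolynomialLR` (1.13 or later, to my knowledge). `epochs`, `iters_per_epoch`, `base_lr` and `power` are hypothetical placeholders, not the actual train.py arguments, and `power` should be set to whatever exponent the script's lambda uses.

```python
import torch

# Hypothetical stand-ins for values train.py derives from its arguments.
epochs = 2
iters_per_epoch = 3
base_lr = 0.01
power = 0.9  # assumption: match the exponent used in the existing lambda
total_iters = iters_per_epoch * epochs


def make_optimizer():
    return torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)


# Existing approach: LambdaLR with a hand-written polynomial factor.
opt_lambda = make_optimizer()
lambda_sched = torch.optim.lr_scheduler.LambdaLR(
    opt_lambda, lambda x: (1 - x / total_iters) ** power
)

# Native replacement.
opt_poly = make_optimizer()
poly_sched = torch.optim.lr_scheduler.PolynomialLR(
    opt_poly, total_iters=total_iters, power=power
)

lambda_lrs, poly_lrs = [], []
for _ in range(total_iters):  # one scheduler step per training iteration
    lambda_lrs.append(opt_lambda.param_groups[0]["lr"])
    poly_lrs.append(opt_poly.param_groups[0]["lr"])
    opt_lambda.step()
    opt_poly.step()
    lambda_sched.step()
    poly_sched.step()

for a, b in zip(lambda_lrs, poly_lrs):
    print(f"{a:.8f} {b:.8f}")
```

Since the native scheduler applies the same chainable ratio discussed above, the two columns should agree up to floating-point rounding while the loop stays within `total_iters` steps.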

federicopozzi33, Aug 11 '22

That's correct. In fact, once the scheduler makes it into the nightly, we can make the change. I'm not sure whether it made it into today's nightly or will appear in tomorrow's, but you can start a PR and I'll review, test and merge it soon. Would that work for you?

datumbox, Aug 11 '22