pytorch-gradual-warmup-lr

multiplier works weird

Open alexwongdl opened this issue 2 years ago • 2 comments

When I modify the example code like this:

import torch
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.optim.sgd import SGD

from warmup_scheduler import GradualWarmupScheduler


if __name__ == '__main__':
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = SGD(model, 0.0001)

    # scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=10, total_epoch=5, after_scheduler=scheduler_steplr)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    optim.zero_grad()
    optim.step()

    for epoch in range(1, 20):
        scheduler_warmup.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])

        optim.step()    # optimizer update (stands in for the real training step)

I get an unexpected result; the learning rate at the sixth epoch is strange:

1 0.00028
2 0.00045999999999999996
3 0.00064
4 0.00082
5 0.001
6 0.0001    
7 0.001
8 0.001
9 0.001
10 0.001
11 0.001
12 0.001
13 0.001
14 0.001
15 0.0001
16 0.0001
17 0.0001
18 0.0001
19 0.0001
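
For reference, the printed values for epochs 1-5 match a linear warmup of the form lr = base_lr * (1 + (multiplier - 1) * epoch / total_epoch), with base_lr=0.0001, multiplier=10, total_epoch=5. Here is a small sketch that reconstructs the expected schedule from those numbers (not necessarily the library's exact internal formula):

base_lr, multiplier, total_epoch = 0.0001, 10, 5
for epoch in range(1, 7):
    if epoch <= total_epoch:
        # linear warmup from base_lr towards base_lr * multiplier
        lr = base_lr * (1 + (multiplier - 1) * epoch / total_epoch)
    else:
        # after warmup, StepLR (step_size=10) has not decayed yet, so the LR should hold
        lr = base_lr * multiplier
    print(epoch, lr)

This reproduces 0.00028 ... 0.001 for epochs 1-5, but gives 0.001 for epoch 6, not the 0.0001 that GradualWarmupScheduler prints.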

alexwongdl avatar May 07 '22 08:05 alexwongdl

I can confirm this behavior. I think line 31 in warmup_scheduler/scheduler.py is troublesome, and that

return self.after_scheduler.get_last_lr()

should rather be:

return self.after_scheduler.get_lr()

I do, however, think the whole scheduler would be easier and less error-prone to implement using the built-in PyTorch scheduler LinearLR for the warmup part, optionally chained with one or more other schedulers (the equivalent of "after_scheduler") using SequentialLR.
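
For what it's worth, here is a minimal sketch of that alternative, assuming PyTorch >= 1.11 (where LinearLR and SequentialLR are available). The start_factor=0.1 is chosen so that the LR starts at 0.0001 and warms up to the target 0.001 over 5 epochs, mirroring multiplier=10 and total_epoch=5 from the example above; this is not an exact drop-in replacement for GradualWarmupScheduler, just the general shape:

import torch
from torch.optim.lr_scheduler import LinearLR, StepLR, SequentialLR
from torch.optim.sgd import SGD

model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optim = SGD(model, 0.001)  # target LR after warmup (base_lr * multiplier above)

# linear warmup from 10% to 100% of the target LR over 5 epochs
warmup = LinearLR(optim, start_factor=0.1, end_factor=1.0, total_iters=5)
# the equivalent of "after_scheduler"
steplr = StepLR(optim, step_size=10, gamma=0.1)
# hand over from warmup to StepLR after 5 epochs
scheduler = SequentialLR(optim, schedulers=[warmup, steplr], milestones=[5])

for epoch in range(20):
    optim.step()       # optimizer update comes first to avoid the ordering warning
    scheduler.step()   # no epoch argument needed
    print(epoch, optim.param_groups[0]['lr'])

With SequentialLR you call scheduler.step() without an epoch argument, and the handoff at the milestone should happen without the one-epoch dip shown above.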

lucasbrynte avatar Dec 06 '22 13:12 lucasbrynte

Just to nuance my comment: it actually seems users are not supposed to call the .get_lr() function directly. It emits a warning when called from anywhere other than .step(); inside .step() the call is wrapped in a with _enable_get_lr_call(self): block, which is how the scheduler marks the call as legitimate.
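
For anyone who wants to check this, here is a small runnable illustration (using a plain StepLR as a stand-in) showing that calling .get_lr() outside of .step() triggers the warning, while .get_last_lr() does not; the exact warning text may differ between PyTorch versions:

import warnings

import torch
from torch.optim.lr_scheduler import StepLR
from torch.optim.sgd import SGD

model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optim = SGD(model, 0.001)
sched = StepLR(optim, step_size=10, gamma=0.1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sched.get_lr()        # called outside .step(), so the flag set by _enable_get_lr_call is absent
    print([str(w.message) for w in caught])   # expect a warning recommending get_last_lr()

print(sched.get_last_lr())   # the supported way to read the current LR; no warning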

lucasbrynte avatar Dec 06 '22 15:12 lucasbrynte