
The custom lr_scheduler interface differs between oneflow and pytorch

Open clackhan opened this issue 3 years ago • 4 comments

pytorch version: 1.12.1+cu102

oneflow:

version: 0.8.1+cu112.git.c0811b327a
git_commit: c0811b327a
cmake_build_type: Debug
rdma: False
mlir: False

As the code below shows, the arguments of get_lr differ between oneflow and pytorch.

oneflow:

from typing import Union

import oneflow as flow
from oneflow.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types  # available in typeguard < 3.0

class WarmupLR(_LRScheduler):
    def __init__(
        self,
        optimizer: flow.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        assert check_argument_types()
        self.warmup_steps = warmup_steps

        super().__init__(optimizer, last_epoch)

    ...

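    def get_lr(self, base_lr, step):
        # oneflow form: receives one base lr and the current step, returns one lr
        step_num = step + 1
        return base_lr * self.warmup_steps**0.5 * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)

pytorch:

from typing import Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types  # available in typeguard < 3.0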

class WarmupLR(_LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        assert check_argument_types()
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)
    ...

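    def get_lr(self):
        # pytorch form: takes no arguments, reads self.last_epoch and
        # self.base_lrs, and returns one lr per param group
        step_num = self.last_epoch + 1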
        return [
            lr
            * self.warmup_steps**0.5
            * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
            for lr in self.base_lrs
        ]
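
To make the difference concrete, here is a minimal framework-free sketch of the two calling conventions. It does not use either library's scheduler base class. The function names `get_lr_oneflow_style`, `get_lr_pytorch_style` and `schedules_match` are hypothetical, and the scheduler state (`base_lrs`, `last_epoch`) is passed in explicitly so the two forms can be compared directly. It assumes `step` in the oneflow form plays the same role as `self.last_epoch` in the pytorch form.

```python
import math
from typing import Iterable, List

WARMUP_STEPS = 25000  # default value used in the snippets above


def get_lr_oneflow_style(base_lr: float, step: int,
                         warmup_steps: float = WARMUP_STEPS) -> float:
    # oneflow-style signature: one base lr and the step are arguments,
    # and a single learning rate is returned.
    step_num = step + 1
    return base_lr * warmup_steps**0.5 * min(step_num**-0.5, step_num * warmup_steps**-1.5)


def get_lr_pytorch_style(base_lrs: List[float], last_epoch: int,
                         warmup_steps: float = WARMUP_STEPS) -> List[float]:
    # pytorch-style form: the scheduler state (passed in here instead of
    # being read from self) yields one learning rate per param group.
    step_num = last_epoch + 1
    return [
        lr
        * warmup_steps**0.5
        * min(step_num**-0.5, step_num * warmup_steps**-1.5)
        for lr in base_lrs
    ]


def schedules_match(base_lrs: List[float], steps: Iterable[int]) -> bool:
    # Hypothetical helper: compare both forms step by step.
    for step in steps:
        per_group = [get_lr_oneflow_style(lr, step) for lr in base_lrs]
        if not all(
            math.isclose(a, b, rel_tol=1e-12)
            for a, b in zip(per_group, get_lr_pytorch_style(base_lrs, step))
        ):
            return False
    return True
```

Under that assumption the two forms differ only in where the inputs come from: calling the oneflow-style function once per base lr yields the same list the pytorch-style function builds internally.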

clackhan • Sep 19 '22 08:09

Zhimin and I will align on this. @small1945

wyg1997 • Sep 21 '22 07:09

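What problem does the difference in implementation cause? https://github.com/Oneflow-Inc/OneTeam/issues/998 explains why we chose not to use the same implementation as pytorch.

The computed results should be the same, though. There is a unit test for warmup at https://github.com/Oneflow-Inc/oneflow/blob/2d60b9a69353c79022e898cb6aaf5f39a44cbfb2/python/oneflow/test/modules/test_lr_scheduler.py#L322. If you find a case where the results do not match, please list the actual test data, and I will fix it and add a unit test case.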

leaves-zwx • Sep 21 '22 11:09

One more question: I searched the latest pytorch master for the WarmupLR class described in this issue and found no such implementation. Where does this implementation come from?

leaves-zwx • Sep 21 '22 12:09

@clackhan, could you continue the discussion on this issue?

BBuf • Sep 27 '22 01:09