
m0 / v1 init

Open rwightman opened this issue 2 years ago • 15 comments

Allo @lucidrains, I've been fiddling with this optimizer and it's looking promising so far. I was looking for other interpretations out there regarding my doubts about the lack of bias correction... I'm assuming it's deemed unnecessary due to the explicit m0 and v1 init, but wasn't 100% sure it wasn't just left out for clarity.

I noticed you left m0 as zero and v1 as an interpolation from a zero init... did you experiment with that vs. the careful init noted in the paper's Algorithm 1?

The core of my attempt is below (note I flipped the betas to be comparable to adam/lamb/etc: .98, .92, .99):

    state = self.state[p]
    if len(state) == 0:
        state['step'] = 0
        state['grad'] = torch.zeros_like(grad)
        state['m'] = torch.clone(grad)  # init m0 = g0
        state['v'] = torch.zeros_like(grad)
        state['n'] = torch.zeros_like(grad)

    m, v, n = state['m'], state['v'], state['n']
    # NOTE first step is no-op as we need g0 & g1 for first grad delta (g1 - g0)
    if state['step'] > 0:
        m.lerp_(grad, 1. - beta1)
        grad_delta = grad - state['grad']
        if state['step'] > 1:
            v.lerp_(grad_delta, 1. - beta2)
        else:
            v.copy_(grad_delta)  # init v1 = g1 - g0
        n.lerp_((grad + beta2 * grad_delta).square(), 1. - beta3)

        # FIXME paper Algorithm 1 includes no bias correction
        # Do the m0 and v1 init special cases obviate the need, or was it just left out of the paper for clarity?
        denom = 1 + group['weight_decay'] * lr
        step_size = lr * (n + group['eps']).rsqrt()
        p.addcmul_(step_size, m.add(v, alpha=beta2), value=-1.).div_(denom)

    state['grad'].copy_(grad)
    state['step'] += 1

rwightman avatar Aug 26 '22 21:08 rwightman

@rwightman Oh hey Ross! Glad to see you are keeping up with the bleeding edge :) I just sent you an email a moment ago about your CLIP experiment

So I did do it the way in the paper initially, but I had to initialize n to grad ** 2, as is done in the restart condition https://github.com/lucidrains/Adan-pytorch/commit/14ec8b31b90c57df9ce9a9a151ec833c0854e989#diff-61c9ea3d62e9746a1092013f1c4d8804f28e654e6bb00da8cd98a527bedc7139R53, for it not to explode on my task (which is a small GPT)

However, I was chatting with @sdtblck and he told me he zero-initialized everything, so I tried it and could not see a difference. So I just left it like that for simplicity
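
For reference, the two variants only differ in how n gets its very first value (a rough sketch keyed to the snippet above; everything else stays the same):

    if len(state) == 0:
        # "careful" init borrowed from the restart condition
        state['n'] = grad * grad
        # vs. the plain zero init that @sdtblck used and that I ended up keeping
        # state['n'] = torch.zeros_like(grad)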

Are you seeing a big difference following the careful init as in the pseudocode?

lucidrains avatar Aug 26 '22 22:08 lucidrains

I'm seeing really poor results if state['n'] is initialized as zeros as you have in your code

lucidrains avatar Aug 26 '22 22:08 lucidrains

[Screenshot from 2022-08-26: training curves for the init variants described below]

Not very rigorous, but: blue is Adam (baseline), red is the careful init with n initialized to grad squared, purple is the all-zeros init, and brown is the careful init without the grad-squared n init

lucidrains avatar Aug 26 '22 22:08 lucidrains

@lucidrains I've also been doing some not very scientific comparisons (restarting training with the same seed and seeing what happens) on one network (a vit-cnn hybrid) with one random init. But I am seeing what you are so far

Careful init w/ n == 0 is not great. All zeros is better. Now trying careful + n == grad ** 2...

rwightman avatar Aug 26 '22 22:08 rwightman

@rwightman awesome! everyone will be eager to hear your results, which are much more authoritative than my toy tasks haha

lucidrains avatar Aug 26 '22 22:08 lucidrains

@lucidrains so, with two network archs now, running through the variations, all zeros with no special-case init definitely appears to be the winner in these tests of limited scope. Hmm...

rwightman avatar Aug 26 '22 23:08 rwightman

Sorry for causing some confusion here. Adan does have the bias correction in the implementation, but we needed to keep the algorithm presentation consistent with the theoretical analysis, so we did not explicitly show it in Algorithm 1. We'll release the code in a few days (2-3 days, since we have a code review procedure). The log and config files will be released together. @rwightman

XingyuXie avatar Aug 29 '22 11:08 XingyuXie

@XingyuXie Hi Xingyu and thanks for the interesting paper

I tested out the bias correction and am indeed seeing a slight improvement: https://github.com/lucidrains/Adan-pytorch/commit/3911a86e41624a5048e687e18d451b3fd5007242 Let me know if you see anything else that does not look quite right!

lucidrains avatar Aug 29 '22 15:08 lucidrains

@lucidrains Thanks for the update; the following are some minor modifications. When implementing Adan, we referred to some of the optimizer implementations in timm.

Line 55:

    state['prev_grad'] = grad

Lines 85-86:

    correct_m = 1 / bias_correct1  # correction term for m'
    correct_v = 1 / bias_correct2  # correction term for v

Line 91:

    weighted_step_size = lr / ((n.sqrt() / sqrt_bias_correct3).add_(eps))
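
Putting these pieces together with the snippet earlier in the thread, the bias-corrected update roughly takes the following shape. This is a sketch only, written with the Adam-style beta convention used at the top of the thread (m = beta1 * m + (1 - beta1) * g); the exact variable names in the released code may differ.

    import math

    # bias-correction factors after `step` updates (Adam-style convention, assumed)
    bias_correct1 = 1 - beta1 ** step                    # for m, the EMA of gradients
    bias_correct2 = 1 - beta2 ** step                    # for v, the EMA of gradient differences
    sqrt_bias_correct3 = math.sqrt(1 - beta3 ** step)    # for n, the EMA of squared terms

    # corrected step size and parameter update, mirroring the lines above
    weighted_step_size = lr / ((n.sqrt() / sqrt_bias_correct3).add_(eps))
    denom = 1 + weight_decay * lr                        # decoupled weight decay applied as a divisor
    p.addcmul_(weighted_step_size,
               m.div(bias_correct1).add_(v.div(bias_correct2), alpha=beta2),
               value=-1.).div_(denom)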

Tips:

  • For fairness and ease of use, we do not enable the restart condition in practice.
  • Adan can tolerate a large peak LR. For example, except for the MAE pre-training and LSTM experiments, Adan's LR is 5-10 times that of Adam/AdamW.
  • Adan seems to be relatively sensitive to beta3. Adjusting beta1 and beta2 has a limited effect on the results, beta2 especially.
  • Interestingly, we found that weight_decay = 0.02 seems to be suitable for most experiments (a quick configuration sketch follows this list).
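
Roughly, a starting configuration based on these tips might look like the following. This is a sketch only: the import path and constructor signature are assumed to mirror a standard PyTorch optimizer and are not copied from the released code, and the baseline LR is purely illustrative.

    import torch
    from torch import nn
    from adan import Adan  # assumed import path; adjust to whichever implementation you use

    model = nn.Linear(128, 10)      # stand-in model for illustration
    adamw_lr = 1e-3                 # whatever LR the AdamW baseline would have used

    optimizer = Adan(
        model.parameters(),
        lr=5 * adamw_lr,            # Adan tolerates a 5-10x larger peak LR
        betas=(0.98, 0.92, 0.99),   # beta3 is the sensitive one; beta1/beta2 matter less
        weight_decay=0.02,          # 0.02 worked well across most experiments
    )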

XingyuXie avatar Aug 30 '22 03:08 XingyuXie

@XingyuXie thanks for the code review!

lucidrains avatar Aug 30 '22 04:08 lucidrains

@lucidrains You're welcome.

By increasing the LR and tuning the warmup steps, the performance may be further improved (a minimal warmup sketch is below). Have fun using Adan.
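
A minimal warmup sketch using PyTorch's LambdaLR, with torch.optim.Adam standing in for Adan and a purely illustrative warmup length:

    import torch
    from torch import nn
    from torch.optim.lr_scheduler import LambdaLR

    model = nn.Linear(16, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)  # stand-in for Adan

    warmup_steps = 1000  # illustrative; tune per task
    # linear warmup from ~0 up to the peak LR, then hold the peak
    scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

    for step in range(3000):
        # loss.backward() and gradient computation would happen here in real training
        optimizer.step()
        scheduler.step()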

XingyuXie avatar Aug 30 '22 05:08 XingyuXie

@XingyuXie I used the optimizer visualization for verification, and it feels like the Adan algorithm is less robust than other algorithms. My TF implementation is here, along with the visualization

cpuimage avatar Aug 30 '22 07:08 cpuimage

Sorry, I am not very familiar with TF. I have tried to add you on WeChat to send our adan.py (implemented in PyTorch), but have not received a response yet. @cpuimage


Thanks @cpuimage. I have visualized Adan on two toy cases. But we must point out that practical performance is more important, since a number of optimizers (e.g., AdaBound and Yogi) can handle these two toy cases yet vary a lot in practical DNN training.

[Plots: Adan trajectories on the Rastrigin and Rosenbrock toy functions]
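
For anyone wanting to reproduce this kind of check, a minimal sketch of running an optimizer on the Rosenbrock function (plain PyTorch, with torch.optim.Adam standing in for whichever optimizer is being tested; start point and step count are arbitrary):

    import torch

    def rosenbrock(xy):
        x, y = xy
        return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

    xy = torch.tensor([-1.5, 2.0], requires_grad=True)   # arbitrary start point
    opt = torch.optim.Adam([xy], lr=1e-2)                # stand-in for the optimizer under test

    trajectory = []
    for _ in range(2000):
        opt.zero_grad()
        loss = rosenbrock(xy)
        loss.backward()
        opt.step()
        trajectory.append(xy.detach().clone())           # record the path for plotting

    # `trajectory` can then be drawn over the loss surface, as in the figures above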

XingyuXie avatar Aug 31 '22 05:08 XingyuXie

I took Adan for a spin today and it looks promising. I am just training my favorite timm backbone, convnext_tiny, with nice results. The only downside I see is that it's slower, noticeably slower than Adam: https://wandb.ai/capecape/adan_optimizer/reports/Adan-The-new-optimizer-that-challenges-Adam--VmlldzoyNTQ5NjQ5 There is also a nice implementation for fastai by Benjamin Warner here: https://github.com/warner-benjamin/fastxtend/blob/main/nbs/optimizers.adan.ipynb

tcapelle avatar Aug 31 '22 08:08 tcapelle

Hi @tcapelle, from the experimental results you released here, the accuracy of Adan's three trials is 71.8 / 75.5 / 74.0, while Adam's three trials give 72.2 / 71.4 / 71.5.

This result seems inconsistent with the curve drawn in the blog, but it is also possible that I missed some key details.

We really appreciate your detailed experiments and suggestions.

BTW, our code has been released at https://github.com/sail-sg/Adan.

It also contains the config files and results for ConvNeXt. Feel free to refer to it; any feedback is welcome.

XingyuXie avatar Sep 01 '22 12:09 XingyuXie