
implement gsam in jax

Open · juntang-zhuang opened this issue 2 years ago

Hi @lucasb-eyer, thanks for your review and comments. I reformatted the files and squashed the commits into a new PR (sorry, I messed up the old PR and could not squash commits there). This PR includes:

  1. Put the GSAM-related configs into config.gsam and call GSAM with l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels, lr=learning_rate, **config["gsam"]) (see the sketch after this list).
  2. Add big_vision/configs/proj/gsam/vit_1k_gsam_no_aug.py; the network used in the GSAM paper uses pool_type='gap' and rep_size=False, which differs from the default config.
  3. Fix formatting issues and squash commits.
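
For context, a rough sketch of how that call could sit inside the pmapped train step; names such as sched_fns, tx, params, and axis_name="batch" follow big_vision conventions but are assumptions on my side, not the exact trainer code:

    import jax
    import optax

    def train_step(params, opt, images, labels, step):
      # Absolute learning rate (see the relative-vs-absolute lr discussion
      # further down in this thread).
      learning_rate = sched_fns[0](step) * config["lr"]
      l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images,
                               targets=labels, lr=learning_rate,
                               **config["gsam"])
      # Average the loss and the returned (GSAM) gradients across workers
      # before the optimizer update.
      l, grads = jax.lax.pmean((l, grads), axis_name="batch")
      updates, opt = tx.update(grads, opt, params)
      params = optax.apply_updates(params, updates)
      return params, opt, l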

Regarding reproducing the experiments, I wonder if it's possible for you to run the script (with 8x8 TPU cores to exactly match the paper)? I'm sorry, I don't have access to TPU resources since I'm no longer affiliated with Google, so I can't run the experiments myself, though the checkpoints and the old version of the code that I used are still kept on a server. Thanks so much for your code review and help!

juntang-zhuang avatar Jul 16 '22 00:07 juntang-zhuang

Regarding running experiments, I could give it a try at some point, but definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.

lucasb-eyer avatar Jul 19 '22 16:07 lucasb-eyer

> Regarding running experiments, I could give it a try at some point, but definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.

Thanks a lot! If the effective wd schedule cannot be figured out, I might need to find a way either to implement the old-style weight decay schedule or to tune the hyper-parameters with the new setting. I wonder if you could point Ting to the docs on how to run this repository internally; I'll submit the code externally, so we can re-run some experiments to reproduce the results?

juntang-zhuang avatar Jul 20 '22 06:07 juntang-zhuang

hey, sorry, I got distracted by something urgent to finish. I will get back to this within the next two weeks and am optimistic we can get it to work well :)

edit: however, you have not yet told me which exact number from the paper the config should reproduce?

lucasb-eyer avatar Aug 05 '22 19:08 lucasb-eyer

Thanks for the response. Sorry about the missing number; it's supposed to reproduce the 76.8 for ViT-B/32 in Table 1 of https://openreview.net/pdf?id=edONMAnhLu- .

I'm not fully sure about the new wdecay and lr schedulers. In the old version, the lr scheduler is a single function (here the lr schedule function seems to be chained with a bunch of other schedulers). Also, in the old version, wdecay is multiplied by lr, so wdecay effectively follows a schedule rather than being constant; is the new wdecay set to a constant?
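
To make the question concrete, here is a schematic single-parameter update contrasting the two conventions, ignoring momentum (an illustration of the question only, not the actual bv_optax chain):

    def decayed_update(p, grad, lr_t, wd, wd_follows_lr):
      # Old-style: wd is multiplied by lr_t, so it effectively follows the lr
      # schedule. New-style(?): wd is a constant, independent of the schedule.
      wd_t = wd * lr_t if wd_follows_lr else wd
      return p - lr_t * grad - wd_t * p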

juntang-zhuang avatar Aug 08 '22 03:08 juntang-zhuang

oh, and you have a bunch of small issues like wrong indentation, trailing spaces, etc. It would be helpful if you could run pylint with this config over it, so I don't need to fix these later on.

lucasb-eyer avatar Aug 08 '22 20:08 lucasb-eyer

and another minor nitpick: could you rename the config from ...1k... to ...i1k...? We never call ImageNet "1k", but always "i1k", throughout the codebase. I assume you made a typo.

lucasb-eyer avatar Aug 08 '22 21:08 lucasb-eyer

Here is the training loss from running this config, sweeping over wd=0.0009 (= 0.3 * 0.003, should be exactly the same as in the paper), 0.001 (a nicer number close to the previous one), and 0.3 (just in case). The loss is crazy, and accuracy is and stays at random (not shown): [image: training-loss curves for the three wd values]

However, I find it somewhat suspicious that it starts at 693.15, roughly 100x the standard starting loss of i1k (log(1000) = 6.907). I noticed the config is using the sigmoid_xent loss, and your paper does not mention the words "softmax" or "sigmoid"; could it be that you trained with softmax_xent and have sigmoid_xent here in the config by mistake? I'll try a run with that instead, but please take another careful read over the config and see if you can find other potential sources of this.
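
As a back-of-the-envelope check of those two starting values (my own numbers, assuming 1000 classes and near-zero logits at initialization):

    import math

    # Softmax cross-entropy per example with uniform predictions over 1000 classes:
    print(math.log(1000))      # ~6.907
    # Sigmoid cross-entropy summed over 1000 classes with logits near zero:
    print(1000 * math.log(2))  # ~693.15

So a sigmoid loss that is summed (rather than averaged) over the classes would start almost exactly at the observed 693.15.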

Another thing: the config does not contain config.init_head_bias, which we often, but not always, use. Could this also be a mistake? (I'll also schedule an experiment about this.)

lucasb-eyer avatar Aug 09 '22 11:08 lucasb-eyer

Thanks a lot for the experiments; it seems the config is not correct. I'll discuss it with Ting and see if we can directly compare the config file with the one we used for the experiments.

juntang-zhuang avatar Aug 09 '22 16:08 juntang-zhuang

So far, no luck: neither of the changes (sigmoid->softmax, head-bias init) made it any better.

Then, I also tried the following things:

  1. Disabling weight-decay altogether, to check whether I can at least overfit. Nope, still an exploding loss, so the issue seems unrelated to wd(?)
  2. A model with cls-token and mlp-head (repr_size=True), as this was the original ViT setup. A complete disaster :)

So, I have tried all the ideas I had regarding the configuration, and at this point I wonder if maybe there's a bug in the implementation. Could you please try on your side? Note that you don't need TPU access to run big_vision; it works great on GPUs too, and we did update the README with instructions about that. Let me know when you figure out a setting/code change such that the loss no longer explodes in the first few hundred steps, and I can then try longer runs for you again. (I'll also ping Ting my runs internally.)

lucasb-eyer avatar Aug 09 '22 20:08 lucasb-eyer

I forgot to mention it, but I also tried a run with Adam's 1st momentum not in bfloat16 but in regular float32, and it makes no difference. Note that this bfloat16 really just affects the 1st-momentum buffer, nothing else.
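
For reference, a minimal sketch of the knob in question, assuming the optimizer goes through optax.adam (the learning-rate value here is just a placeholder):

    import jax.numpy as jnp
    import optax

    # mu_dtype only changes the dtype of the 1st-moment accumulator; the 2nd
    # moment and the parameters themselves stay in float32 either way.
    tx_bf16 = optax.adam(learning_rate=3e-3, mu_dtype=jnp.bfloat16)
    tx_f32 = optax.adam(learning_rate=3e-3, mu_dtype=jnp.float32)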

lucasb-eyer avatar Aug 11 '22 09:08 lucasb-eyer

Ting shared with me your exact runs behind the paper numbers, so I could dig in a bit more. Carefully replicating exactly the config that was run, I still get similar behaviour, though slightly less extreme ("only" going up to the hundreds, not millions): [image: training-loss curve]

At this point, I feel like this must be a bug in the code. It seems to go wrong after ~500 steps, potentially you can even run that on CPUs to debug?

lucasb-eyer avatar Aug 11 '22 21:08 lucasb-eyer

Thanks a lot for the feedback and experiments. I'll dig into it with Ting and will post a working version here. Sorry for all the trouble with this PR.

juntang-zhuang avatar Aug 12 '22 22:08 juntang-zhuang

> Sorry for all the trouble with this PR

No worries, I will be happy and thankful to have up-to-date GSAM and SAM in the codebase!

lucasb-eyer avatar Aug 18 '22 14:08 lucasb-eyer

I also tried to run this with alpha=0, and it looks slightly better at the start, but it still explodes after 1-2k steps.

evcu avatar Aug 18 '22 18:08 evcu

I just noticed in one of your changes a few days ago, you did find a bug:

    learning_rate = sched_fns[0](step)   # Wrong
    learning_rate = sched_fns[0](step) * config["lr"]   # Your fix
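
In other words (with rough numbers; the exact schedule values are assumptions on my side): bv_optax schedule functions return a relative multiplier, so the absolute learning rate only appears after scaling by the base lr from the config.

    # Buggy: the relative multiplier (~1.0 after warmup) was passed as "lr"
    # into gsam_gradient, which expects an absolute lr on the order of
    # config["lr"] = 0.003 -- so the perturbation came out roughly 300x too
    # large, consistent with the exploding loss above.
    learning_rate_buggy = sched_fns[0](step)                 # ~1.0
    learning_rate_fixed = sched_fns[0](step) * config["lr"]  # ~0.003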

This looks very promising! So I patched it in and tried another run on top of the last one I mentioned here. It looks a lot better! It doesn't explode, and reaches 75.2/81.8/61.0 validation/real/v2 accuracy after 90 epochs. This is not yet the expected 76.8/82.7/63.0 we're trying to reproduce, but it's getting much closer :partying_face:

However, the missing 1.6% is still significant, so we should find it before merging this. I carefully compared the configs (already before, but once again) and didn't find a new discrepancy. With alpha=0 I should get SAM, right? Were the SAM and Vanilla numbers in Table 1 also produced by you, or copied from somewhere? If produced by you, I could also run SAM and Vanilla and see if I can reproduce them; that would give us an indication of where the remaining mistake might be.

Here are a few metrics; do they all look reasonable to you? [image: training metrics]

lucasb-eyer avatar Aug 18 '22 21:08 lucasb-eyer

@lucasb-eyer Thanks so much for running the experiments! I'm also running an experiment on ViT-S/32, but it takes much longer on my GPU machine; I'll post the results here after it finishes.

The results for SAM are copied from https://arxiv.org/abs/2106.01548, Table 2. The 1.6% gap might come from:

  • in the paper, ViT is trained for 300 epochs (here it's 90),
  • a bug related to point 2 below,
  • the TPU topology: I used 8x8 TPU cores for most experiments, and for the SAM family a larger number of TPU cores typically increases performance.

In previous updates, I made a few changes that potentially make a difference, including the following:

  1. pass the absolute learning rate, learning_rate = sched_fns[0](step) * config["lr"], instead of learning_rate = sched_fns[0](step);
  2. in config.gsam, set absolute values lr_max=config.get_ref('lr') and lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr');
  3. in config.schedule, set linear_end=0.01 (rather than linear_end=0.00003);
  4. pass flax.jax_utils.replicate(step) when calling update_fn.

(I'm not sure if 4 is necessary, just following my old code after meeting with Ting.)

For 1, it's my fault that I did not realize bv_optax defines the learning rate schedule in a relative manner, while all my code from last year assumes the lr values are absolute. This caused a bug in my previous PR: I passed an absolute lr to the denominator but a relative lr to the numerator, which results in a roughly 300x larger perturbation amplitude. Such a big perturbation would crash the network. In the current version this should be fixed.

For 2 and 3, this is also caused by my mistake with the lr schedule. To reproduce the paper results, the absolute learning rate follows a linear decay with max_lr=0.003 and min_lr=0.00003. Switching to the relative-ratio schedule, this corresponds to linear_end = 0.00003 / 0.003 = 0.01.
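
A minimal sketch of how points 2 and 3 above fit together in the config (the use of ml_collections and any field names beyond those quoted above are assumptions, not the exact PR contents):

    import ml_collections

    config = ml_collections.ConfigDict()
    config.lr = 0.003  # absolute peak learning rate from the paper

    # Relative schedule: it decays from 1.0 down to linear_end, and
    # linear_end = 0.00003 / 0.003 = 0.01 recovers min_lr = 0.00003.
    config.schedule = ml_collections.ConfigDict(dict(linear_end=0.01))

    # GSAM receives the *absolute* lr range via references, so it stays in
    # sync if config.lr is ever swept.
    config.gsam = ml_collections.ConfigDict(dict(
        lr_max=config.get_ref('lr'),
        lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'),
    ))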

I have merged the changes above into the latest PR; let me know if you have time to take a look. I'm also reproducing the ViT-S/32 results on my machine; it's a bit slow, but I will post them here once I get results. Thanks again for your help with this!

juntang-zhuang avatar Aug 19 '22 04:08 juntang-zhuang

No need to blame yourself alone, I also should have noticed ALL of these during review and testing, but didn't :) Happy you found them now! Let me start some runs right away, for 300ep, and report back later today.

I actually ran all experiments on 8x8, but am curious why TPU topology would influence the results?

lucasb-eyer avatar Aug 19 '22 08:08 lucasb-eyer

Cool, I'm really excited to see the updated results; they outperform the numbers in the paper! I have updated the PR according to your comments, except that the step is passed to update_fn rather than read out from opt.

One minor thing: reducing GSAM to SAM requires alpha=0 and rho_max=rho_min in the gsam_gradient function. Basically, SAM uses a constant perturbation radius rho_t, while GSAM scales rho_t proportionally to the learning-rate schedule. It might not be ideal to obtain a constant rho by setting rho_max=rho_min; maybe using a bv_optax-style schedule function would be better for code-style consistency.

Regarding the number of TPU cores: GSAM / SAM performs a per-worker perturbation based on the per-worker gradient in gsam_gradient, so more workers means more distinct perturbations, and the model effectively sees more neighbors in parameter space.

juntang-zhuang avatar Aug 19 '22 20:08 juntang-zhuang

Thanks for your comments. My "SAM" run with rho_max=rho_min=0.15 just finished, and it's quite a bit better than the paper number too. From my reading of the code, when rho_max=rho_min we do use a constant rho value, independent of the learning-rate (schedule), no? [image: metrics for the SAM run]

And yes, making it use the actual schedule_fn from optax would be ideal; then we could simply use SAM with all kinds of schedules, and we wouldn't need to manually specify lr_min/lr_max in the config anymore. That would be a lot better, but I thought I had already asked a lot from you, so I didn't want to ask for that too :) If you want to do it, that's great; otherwise I may do it at some point, or maybe never, if we never need it. But this is the largest argument against having it in the core trainer for now.

lucasb-eyer avatar Aug 19 '22 20:08 lucasb-eyer

Regarding the per-host perturbations, I noticed that the model soups paper states that not syncing may have a significant disadvantage: [image: excerpt from the model soups paper]

so it may be worth implementing. Do I understand correctly that it basically means doing jax.lax.pmean(g_clean)?

lucasb-eyer avatar Aug 19 '22 20:08 lucasb-eyer

I also just realized that we should add a pointer to this from the README. I'll do so early next week too.

lucasb-eyer avatar Aug 19 '22 20:08 lucasb-eyer

Thanks so much for your help with the debug and PR!

Regarding the rho_t schedule: yes, it is constant when rho_max=rho_min. I implemented it so that rho_t follows the same schedule as lr_t (except that they have different value scales). It might be better to pass rho_t as another sched_fn, but I'm not familiar with the chain-style functions in bv_optax, so I'm not confident I could implement it correctly and consistently with the existing codebase.
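
A small sketch of that scheduling rule as I understand it from this discussion (the function name and exact form are a paraphrase, not necessarily the code in gsam.py):

    def rho_schedule(lr_t, lr_max, lr_min, rho_max, rho_min):
      """rho_t follows the lr schedule; rho_max == rho_min gives plain SAM."""
      if lr_max == lr_min:
        return rho_max
      return rho_min + (rho_max - rho_min) * (lr_t - lr_min) / (lr_max - lr_min)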

Regarding the per-worker perturbation, the model soups paper seems to contradict the original SAM paper, https://arxiv.org/pdf/2010.01412.pdf, Section 4.1. It defines m-sharpness, where m is the per-worker number of examples, and finds that a smaller m (hence a larger number of workers when the total batch size is fixed) improves generalization.

I'm not quite sure about the model soups implementation. In my implementation (and in SAM), the process is:

  1. Compute the per-worker gradient g_clean (not synced) and the per-worker perturbation param_sam: https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/gsam.py#L69
  2. Compute the per-worker gradient g_gsam at the (per-worker) perturbed model weights param_sam: https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/gsam.py#L91
  3. Average g_gsam across workers in https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/train.py#L211 (note that the grads returned by the gsam_gradient function here are g_gsam, not g_clean).
  4. All workers update with the same globally averaged g_gsam in the optimizer.

I'm not quite sure about model soups, but I suspect that if it draws the opposite conclusion from the SAM paper, it might come from a different implementation. For example, if it switches the order of 3 and 4 (first performing a per-worker parameter update with the per-worker g_gsam, then averaging model weights across workers), this might harm performance compared to a synced perturbation.

If we want to perform a synced perturbation, we can add g_clean = jax.lax.pmean(g_clean, axis_name) after https://github.com/google-research/big_vision/blob/136deda7827aa52143905afc0482b79b7438c8f8/big_vision/trainers/proj/gsam/gsam.py#L56 so that param_sam is the same for all workers. A sketch of the two variants is below.
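
A minimal sketch of the difference, assuming the surrounding train step is pmapped with axis_name="batch" (loss_fn, params, and rho are placeholders here, not the exact gsam.py code):

    import jax
    import jax.numpy as jnp

    def sam_perturbation(loss_fn, params, rho, sync_perturbation=False):
      # Clean per-worker gradient at the current weights.
      l_clean, g_clean = jax.value_and_grad(loss_fn)(params)

      if sync_perturbation:
        # Synced variant: average the clean gradient across workers first, so
        # every worker computes the same perturbation direction.
        g_clean = jax.lax.pmean(g_clean, axis_name="batch")

      # Normalize the gradient and step to the perturbed weights param_sam.
      g_norm = jnp.sqrt(sum(
          jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(g_clean)))
      param_sam = jax.tree_util.tree_map(
          lambda p, g: p + rho * g / (g_norm + 1e-12), params, g_clean)
      return l_clean, param_sam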

juntang-zhuang avatar Aug 20 '22 06:08 juntang-zhuang