DALLE-pytorch

"adamw" optimizer + weight decay = poor generations

Status: Open • afiaka87 opened this issue 3 years ago • 12 comments

https://github.com/lucidrains/DALLE-pytorch/discussions/139#discussioncomment-560790

It appears that AdamW does work better, but the weight decay is producing strange generations.

I'm getting the same strange "brown" generations even though the loss keeps going down, albeit at a fairly slow rate. If you're training with --fp16, it's hard to tell the generations are poor until after training, because images can't be logged to wandb.

afiaka87 · Apr 06 '21 20:04

Is a good temporary fix to just set the weight_decay parameter to zero? @kobiso suggested as much, but I assumed that effectively turns it into a plain old Adam optimizer? I'm out of my depth here.
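On whether weight_decay=0 turns AdamW into plain Adam: the scalar sketch below suggests it does. It is a minimal pure-Python model of a single update, assuming the usual formulations (Adam folds an L2 term into the gradient, which is how torch.optim.Adam treats its weight_decay argument; AdamW shrinks the parameter directly by lr * weight_decay before the Adam step, as torch.optim.AdamW does). The function names and gradient values are illustrative, not taken from DALLE-pytorch.

```python
import math


def adam_like_step(p, g, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                   wd=0.0, decoupled=True):
    """One update of a single scalar parameter p with gradient g at step t (1-based)."""
    b1, b2 = betas
    if decoupled:
        # AdamW-style: shrink the parameter directly, outside the adaptive update.
        p = p * (1 - lr * wd)
    else:
        # Adam + L2-style: fold the decay into the gradient, so it is also
        # rescaled by the adaptive denominator below.
        g = g + wd * p
    # Standard Adam moment estimates with bias correction.
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


def train(grads, wd, decoupled):
    """Apply a fixed sequence of gradients to a parameter starting at 1.0."""
    p, m, v = 1.0, 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        p, m, v = adam_like_step(p, g, m, v, t, wd=wd, decoupled=decoupled)
    return p


grads = [0.5, -0.2, 0.1, 0.3, -0.4]  # illustrative gradients
adam_p = train(grads, wd=0.0, decoupled=False)
adamw_p = train(grads, wd=0.0, decoupled=True)
print(adam_p == adamw_p)  # True: with weight decay 0 the two updates are identical
```

With wd=0 neither branch changes the parameter or the gradient, so the trajectories match exactly; with the paper's 4.5e-2 they diverge, since the handling of weight decay is the only thing that separates the two optimizers.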

afiaka87 · Apr 07 '21 05:04

@lucidrains I noticed AdamW was removed. Should I keep this issue open, since AdamW is what the paper uses?

afiaka87 · Apr 07 '21 06:04

Yep, let's keep it open since it's from the paper :)

kobiso · Apr 07 '21 12:04

The default weight_decay is .0 anyway, isn't it?

robvanvolt · Apr 07 '21 19:04

@robvanvolt the default weight_decay is 0, but the DALL-E paper used 4.5 × 10^-2 (i.e. 4.5e-2).

kobiso · Apr 08 '21 01:04

@kobiso @lucidrains @robvanvolt

So, I'm not having this problem anymore. I'm not sure exactly when it was fixed, but I can no longer reproduce it.

Below are two of the many good samples I'm getting from training on a t-shirt dataset.

For the optimizer, I tried to follow the paper:

from torch.optim import AdamW
opt = AdamW(dalle.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.96), weight_decay=4.5e-2, amsgrad=True)

I've also found 3.7e-4 to be a decent learning rate; that's what I used here.
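To get a feel for how strong weight_decay=4.5e-2 is at this learning rate: with decoupled decay, each step multiplies every weight by (1 - lr * weight_decay) = 1 - 3.7e-4 × 4.5e-2 = 1 - 1.665e-5. The sketch below compounds that factor over a few hypothetical step counts. It assumes a constant learning rate and ignores the gradient term, so it only shows how hard the decay pulls toward zero on its own, not an actual training trajectory.

```python
import math

LEARNING_RATE = 3.7e-4   # value reported in the comment above
WEIGHT_DECAY = 4.5e-2    # value from the DALL-E paper


def decay_only_factor(steps, lr=LEARNING_RATE, wd=WEIGHT_DECAY):
    """Fraction of a weight left after `steps` AdamW updates from decay alone.

    Assumes a constant learning rate and ignores the gradient contribution.
    """
    return (1 - lr * wd) ** steps


# Hypothetical step counts, chosen only for illustration.
for steps in (1_000, 10_000, 100_000):
    print(f"{steps:>7} steps -> {decay_only_factor(steps):.3f}")
```

This works out to roughly 0.98, 0.85 and 0.19, i.e. close to exp(-lr * wd * steps): the decay is tiny per step but far from negligible over a long run.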

Thanks to experimentation (and the sunk-cost fallacy), this network uses the following attention types:

attn_types=('full', 'axial_row', 'axial_col', 'full')
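As I understand it, DALLE-pytorch repeats the attn_types tuple across the transformer's layers until every layer has a type (an itertools cycle/islice pattern). Under that assumption, and with a hypothetical depth of 8 (the comment doesn't state the depth), the configuration above would be assigned like this:

```python
from itertools import cycle, islice

attn_types = ('full', 'axial_row', 'axial_col', 'full')
depth = 8  # hypothetical; the actual model depth is not given in this thread

# Assumed behaviour: repeat the tuple until every layer has an attention type.
layer_types = list(islice(cycle(attn_types), depth))
for i, kind in enumerate(layer_types):
    print(i, kind)
```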

[Sample generations: media_images_image_701_b2f731fa, media_images_image_1301_9e6df8f2]

afiaka87 · Apr 27 '21 05:04

https://github.com/lucidrains/DALLE-pytorch/pull/220

@lucidrains this has been steadily improving my results. I say we put it back in.

afiaka87 · Apr 29 '21 06:04

Okay, AdamW with the OpenAI defaults has been merged back in:

https://github.com/lucidrains/DALLE-pytorch/pull/220

afiaka87 · Apr 29 '21 19:04

Hm, I now realize the actual problem is that the optimizer and scheduler states are not saved with the model checkpoint, so they can't be restored when resuming. If you have both AdamW and LR_Decay turned on and then resume, the scheduler starts over with the learning rate meant for the beginning of training, which causes the bad generations.
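A minimal sketch of a resumable (non-DeepSpeed) checkpoint in plain PyTorch, assuming the fix is to save the optimizer's and scheduler's state_dict() alongside the weights. The nn.Linear model, the ExponentialLR scheduler (a stand-in for whatever LR_Decay actually uses), the make_training_objects helper and the checkpoint key names are all hypothetical. The point is that after load_state_dict() the learning rate continues from where training stopped instead of restarting at its initial value.

```python
import io

import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR


def make_training_objects(model):
    """Hypothetical helper: build the optimizer and a stand-in LR scheduler."""
    opt = AdamW(model.parameters(), lr=3.7e-4, betas=(0.9, 0.96), weight_decay=4.5e-2)
    sched = ExponentialLR(opt, gamma=0.5)  # stand-in for the real LR decay
    return opt, sched


torch.manual_seed(0)
model = nn.Linear(4, 4)  # toy stand-in for the DALLE model
opt, sched = make_training_objects(model)

# Simulate a few epochs of training with LR decay.
for _ in range(3):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()

# Save everything needed to resume, not just the weights.
buf = io.BytesIO()  # in practice this would be a file path
torch.save({
    "weights": model.state_dict(),
    "opt_state": opt.state_dict(),
    "scheduler_state": sched.state_dict(),
}, buf)

# Resume: rebuild the objects, then restore their saved states.
buf.seek(0)
ckpt = torch.load(buf)
resumed_model = nn.Linear(4, 4)
resumed_model.load_state_dict(ckpt["weights"])
resumed_opt, resumed_sched = make_training_objects(resumed_model)
resumed_opt.load_state_dict(ckpt["opt_state"])
resumed_sched.load_state_dict(ckpt["scheduler_state"])

print(sched.get_last_lr(), resumed_opt.param_groups[0]["lr"])
```

Without the two load_state_dict() calls on the optimizer and scheduler, the resumed optimizer would start again at lr=3.7e-4 rather than the decayed 3.7e-4 × 0.5³.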

@janEbert is that covered by your DeepSpeed fix branch?

afiaka87 · May 01 '21 13:05

Yeah, by default DeepSpeed saves (and loads) the optimizer and LR scheduler states, so DeepSpeed checkpoints don't have this problem with the default settings.

The default non-DeepSpeed checkpoints are not suited for resuming, only for inference!

janEbert · May 03 '21 08:05

I'm recommending that we remove AdamW from the main codebase again; unfortunately, I was flat-out wrong about it working.

Here's a PR that does so: https://github.com/lucidrains/DALLE-pytorch/pull/227

[Screenshot from 2021-05-03 09-32-33]

Okay! I actually have no clue what's causing this, because it happened without any resume involved.

This is such a subtle thing to catch because you need to run a few epochs before it shows up. Here's a full run where I did not resume and the problem still occurs:

https://wandb.ai/afiaka87/starting_over/reports/Weight-Decay-Bug--Vmlldzo2NTgxMjA?accessToken=rigmz991xq7blj8fuesbwtrnmyi86nsjranwmphzj79unjx8ilu4akjow2pqd86i

afiaka87 · May 03 '21 14:05

Any updates on this? Can I use AdamW with weight decay? I got similar brown results.

shizhediao · Feb 23 '22 00:02