DALLE-pytorch
"adamw" optimizer + weight decay = poor generations
https://github.com/lucidrains/DALLE-pytorch/discussions/139#discussioncomment-560790
It appears as though AdamW does work better, but the weight decay is creating strange generations.
I'm getting the same strange "brown" generations even though the loss continues to go down, albeit at a pretty slow rate. And if you're working with --fp16, it's tough to know the generations are poor until after training, because images can't be submitted to wandb in that mode.
Is a good temporary solution to this just to set the weight_decay
parameter to zero? @kobiso said as much, but I assumed that effectively just turns it into a plain old Adam optimizer? Out of my depth.
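For what it's worth, here's a minimal sketch of that point (the tiny model is just a placeholder): with weight_decay=0 the decoupled decay term in AdamW vanishes, so its update matches plain Adam without weight decay.

```python
import torch
from torch import nn

# Placeholder model purely to illustrate the optimizer behaviour.
model = nn.Linear(10, 10)

# AdamW with weight_decay=0 applies no decoupled weight decay,
# so its update rule reduces to plain Adam.
opt_adamw_no_decay = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.0)

# Equivalent behaviour with the plain Adam optimizer (its default weight_decay is 0).
opt_adam = torch.optim.Adam(model.parameters(), lr=3e-4)
```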
@lucidrains Noticed the adamw removal. Should I keep this open since it's from the paper?
yeap, let's keep it open since it's from the paper :)
The default weight_decay is .0 anyway, isn't it?
@robvanvolt the default weight_decay is 0, but the DALL-E paper used 4.5e-2.
@kobiso @lucidrains @robvanvolt
So - I'm not having this problem anymore. I'm not sure exactly when we fixed it, but I can no longer reproduce this issue.
These are two of a bunch of good samples I'm getting training on a t-shirts dataset.
I tried to follow the paper (with regard to the optimizer).
opt = AdamW(dalle.parameters(), lr=LEARNING_RATE, betas=(0.9,0.96), weight_decay=4.5e-2, amsgrad=True)
I also have found a decent learning rate to be 3.7e-4. That's what I used here.
Due to experimentation and sunk cost fallacy, this network has the attention types:
attn_types=('full', 'axial_row', 'axial_col', 'full')
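For anyone wanting to try the same setup, here's a rough, self-contained sketch modeled on the repo's README. Everything except the attn_types, the learning rate, and the AdamW settings quoted above is an illustrative placeholder, not the configuration actually used for the t-shirt run.

```python
from torch.optim import AdamW
from dalle_pytorch import DiscreteVAE, DALLE

# Placeholder VAE - swap in your own trained DiscreteVAE (or a pretrained VAE).
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512,
    hidden_dim = 64,
)

# Placeholder DALL-E hyperparameters, plus the attention types mentioned above.
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 8,
    heads = 8,
    dim_head = 64,
    attn_types = ('full', 'axial_row', 'axial_col', 'full'),
)

# Optimizer settings following the DALL-E paper, with the learning rate mentioned above.
LEARNING_RATE = 3.7e-4
opt = AdamW(
    dalle.parameters(),
    lr = LEARNING_RATE,
    betas = (0.9, 0.96),
    weight_decay = 4.5e-2,
    amsgrad = True,
)
```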
https://github.com/lucidrains/DALLE-pytorch/pull/220
@lucidrains this has been steadily improving my results. I say we put it back in.
Okay AdamW with the OpenAI defaults is merged back in:
https://github.com/lucidrains/DALLE-pytorch/pull/220
Hm - so I realize now that the problem is actually that the optimizer and scheduler state is not saved with the model checkpoint for resuming. If you have both AdamW and LR decay turned on and try to resume, the scheduler will restart with a learning rate tuned for the beginning of training, causing the bad generations.
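If it helps, here's a minimal sketch (with made-up key names, not the trainer script's actual checkpoint format) of saving the optimizer and scheduler state alongside the weights so a resume can pick up the LR schedule where it left off:

```python
import torch

def save_checkpoint(path, epoch, dalle, opt, scheduler):
    # Save model, optimizer, and scheduler state together so a resume
    # does not restart the learning-rate schedule from scratch.
    torch.save({
        'epoch': epoch,
        'weights': dalle.state_dict(),
        'opt_state': opt.state_dict(),
        'scheduler_state': scheduler.state_dict(),
    }, path)

def load_checkpoint(path, dalle, opt, scheduler):
    ckpt = torch.load(path, map_location='cpu')
    dalle.load_state_dict(ckpt['weights'])
    opt.load_state_dict(ckpt['opt_state'])
    scheduler.load_state_dict(ckpt['scheduler_state'])
    return ckpt['epoch']
```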
@janEbert is that in your deepspeed fix branch?
Yeah, DeepSpeed by default loads (and saves) the optimizer and LR scheduler states. So the DeepSpeed checkpoints do not have this problem with the default settings.
The default non-DeepSpeed checkpoints are not suited for resuming, only for inference!
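Roughly, the DeepSpeed pattern being described looks like the sketch below; the config path is illustrative and it assumes `dalle`, `opt`, and `scheduler` already exist.

```python
import deepspeed

# Wrap the model, optimizer, and LR scheduler into a DeepSpeed engine.
model_engine, opt, _, scheduler = deepspeed.initialize(
    model=dalle,
    optimizer=opt,
    lr_scheduler=scheduler,
    model_parameters=dalle.parameters(),
    config='deepspeed_config.json',  # illustrative config path
)

# save_checkpoint stores optimizer and LR scheduler state alongside the weights.
model_engine.save_checkpoint('./checkpoints', tag='latest')

# load_checkpoint restores them by default (load_optimizer_states=True,
# load_lr_scheduler_states=True), so a resume continues the LR schedule correctly.
model_engine.load_checkpoint('./checkpoints', tag='latest')
```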
I'm advising we get rid of AdamW from the main codebase again - I was flat-out wrong about it working again, unfortunately.
Here's a PR which does so. https://github.com/lucidrains/DALLE-pytorch/pull/227
Okay! I actually have no clue what's causing this, because this time it happened without a resume involved.
This is such a subtle thing to catch because it requires running a few epochs before you can see it happening. Here's a full run where I did not use resume and the problem still occurs.
https://wandb.ai/afiaka87/starting_over/reports/Weight-Decay-Bug--Vmlldzo2NTgxMjA?accessToken=rigmz991xq7blj8fuesbwtrnmyi86nsjranwmphzj79unjx8ilu4akjow2pqd86i
Any updates on this? Can I use AdamW with weight decay? I'm getting similarly brown results.