DALLE2-pytorch

decoder generates noise images

Open · xinmiaolin opened this issue on Jun 23, 2022 · 26 comments

Hi,

I am trying to train a decoder on the CUB-200 dataset. However, after nearly 400 epochs, the decoder still only generates noise, and I have been trying for a long time to figure out why. I would appreciate your suggestions!

[Screenshot: training loss curve]

So the training loss decreases pretty steadily, but the images still look like this: [Screenshot: generated samples, still pure noise]

I have used only 1 unet, configured with dim = 128, image_embed_dim = 768, dim_mults = (1, 2, 4, 8). For the decoder, the parameters are image_size = 64, timesteps = 1000, image_cond_drop_prob = 0.1, text_cond_drop_prob = 0.5, learned_variance = False.
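In code, my setup looks roughly like this (following the README example; the CLIP adapter choice and the dummy batch here are just placeholders):

```python
import torch
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

# ViT-L/14 produces 768-dim image embeddings, matching image_embed_dim below
clip = OpenAIClipAdapter('ViT-L/14')

unet = Unet(
    dim = 128,
    image_embed_dim = 768,
    dim_mults = (1, 2, 4, 8)
).cuda()

decoder = Decoder(
    unet = unet,
    clip = clip,
    image_size = 64,
    timesteps = 1000,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    learned_variance = False
).cuda()

# one training step on a dummy batch; real CUB-200 images would go here
images = torch.randn(128, 3, 64, 64).cuda()
loss = decoder(images)
loss.backward()
```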

Thank you!

xinmiaolin avatar Jun 23 '22 17:06 xinmiaolin

@xinmiaolin hi, it looks like your training run diverged early

the loss should go down to around 0.05 before the images come into view

how high is your learning rate?

lucidrains avatar Jun 23 '22 18:06 lucidrains

The learning rate is 4e-3. Yes, the training loss does drop precipitously, from around 1 to 0.1, in the first epoch, but then it jumps sharply back up to 0.8 before decreasing steadily again. By 0.05, do you mean the loss per image? I am using a batch size of 128.

Thank you very much!

xinmiaolin avatar Jun 23 '22 18:06 xinmiaolin

@xinmiaolin do you mean 3e-4? because 4e-3 is absurdly high!

lucidrains avatar Jun 23 '22 18:06 lucidrains

@xinmiaolin would recommend 1e-4
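
with the repo's DecoderTrainer that would look something like this (the non-lr values are just the readme example's defaults):

```python
from dalle2_pytorch import DecoderTrainer

decoder_trainer = DecoderTrainer(
    decoder,                       # the decoder configured earlier in this thread
    lr = 1e-4,                     # the recommended learning rate
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

loss = decoder_trainer(images, unet_number = 1)
decoder_trainer.update(unet_number = 1)    # take an optimizer step and update the EMA
```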

lucidrains avatar Jun 23 '22 18:06 lucidrains

@xinmiaolin yea, the loss should be the MSE, which is averaged across image samples
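
schematically (shapes illustrative, the prediction here is just a stand-in for the unet's output):

```python
import torch
import torch.nn.functional as F

# the simple DDPM objective: MSE between the noise added during diffusion
# and the unet's prediction of it, averaged over the batch and all pixels
noise = torch.randn(128, 3, 64, 64)        # epsilon added to the image
pred_noise = torch.randn(128, 3, 64, 64)   # stand-in for the unet's output
loss = F.mse_loss(pred_noise, noise)       # reduction = 'mean' by default
```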

lucidrains avatar Jun 23 '22 18:06 lucidrains

> The learning rate is 4e-3. Yes, the training loss does drop precipitously, from around 1 to 0.1, in the first epoch, but then it jumps sharply back up to 0.8 before decreasing steadily again. By 0.05, do you mean the loss per image? I am using a batch size of 128.

yup, when losses jump up like that, it means the training is unhealthy. some of these jumps are not recoverable
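
if spikes keep showing up even at a lower learning rate, clipping the gradient norm is one generic mitigation you could try (plain pytorch sketch, max_norm illustrative; decoder/images/optimizer as set up in your training script):

```python
import torch

loss = decoder(images)
loss.backward()

# cap the global gradient norm before the optimizer step
torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm = 1.0)
optimizer.step()
optimizer.zero_grad()
```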

lucidrains avatar Jun 23 '22 18:06 lucidrains

I will try a lower learning rate to see if this kind of jump appears again.

xinmiaolin avatar Jun 23 '22 18:06 xinmiaolin

@xinmiaolin you'll definitely see something

DDPMs are so much easier to train than their predecessors, GANs

lucidrains avatar Jun 23 '22 18:06 lucidrains

> @xinmiaolin you'll definitely see something
>
> DDPMs are so much easier to train than their predecessors, GANs

Thanks, this definitely gives me hope haha

xinmiaolin avatar Jun 23 '22 18:06 xinmiaolin

@xinmiaolin did it work?

lucidrains avatar Jun 24 '22 15:06 lucidrains

Hi, I don't think that it is working...

Here is the run with lr=1e-4: although it diverges at times, the loss can still decrease. [Screenshot: training loss curve, lr=1e-4]

This is with lr=3e-6: the loss oscillates a lot and does not decrease at all; it even increases. [Screenshot: training loss curve, lr=3e-6]

I am confused about why the loss fluctuates so much when the lr is smaller. I am now training with lr=1e-5, and it looks good so far. Will keep updating!

xinmiaolin avatar Jun 24 '22 18:06 xinmiaolin

@xinmiaolin how high is your batch size?

lucidrains avatar Jun 24 '22 18:06 lucidrains

> @xinmiaolin how high is your batch size?

the batch size is 128

xinmiaolin avatar Jun 24 '22 18:06 xinmiaolin

ohh that should be good enough

ok, keep at it with 1e-5, i'd be surprised if that didn't work

lucidrains avatar Jun 24 '22 18:06 lucidrains

> ohh that should be good enough
>
> ok, keep at it with 1e-5, i'd be surprised if that didn't work

ok, I will come back and update. Thanks

xinmiaolin avatar Jun 24 '22 18:06 xinmiaolin

@xinmiaolin i'll add a learning rate warmup for the decoder trainer some time this weekend; just need to figure out how to make it huggingface accelerate compatible
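
in the meantime, a standalone linear warmup can be bolted on with plain pytorch, something like this (warmup_steps illustrative; decoder and dataloader from your training script):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.Adam(decoder.parameters(), lr = 1e-4)

# scale the lr linearly from ~0 up to its target over the first warmup_steps steps
warmup_steps = 1000
scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

for images in dataloader:          # your CUB-200 dataloader
    loss = decoder(images.cuda())
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```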

lucidrains avatar Jun 24 '22 19:06 lucidrains

> @xinmiaolin i'll add a learning rate warmup for the decoder trainer some time this weekend; just need to figure out how to make it huggingface accelerate compatible

Ok thanks!

xinmiaolin avatar Jun 24 '22 21:06 xinmiaolin

This is the training with lr=1e-5; I have also used CosineAnnealingLR with T_max=10000. The loss increases from around 0.3 to 0.6 and then does not decrease at all.
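
(For reference, I attached the scheduler roughly like this, on the optimizer from my training script:)

```python
from torch.optim.lr_scheduler import CosineAnnealingLR

# decay the lr along a cosine curve over 10000 scheduler steps
scheduler = CosineAnnealingLR(optimizer, T_max = 10000)
```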

[Screenshot: training loss curve, lr=1e-5 with cosine annealing]

For lr=3e-5, the loss increases from around 0.2 to 0.7, then decreases slowly. [Screenshot: training loss curve, lr=3e-5]

I have also tried smaller learning rates without the lr scheduler, but the loss still fluctuates, for example with lr=3e-7:

[Screenshot: training loss curve, lr=3e-7]

I am not sure what is going on. Could there be some diffusion-related parameters that should be changed? I am currently using the default values for all of them.

xinmiaolin avatar Jun 26 '22 18:06 xinmiaolin

@xinmiaolin hmm, could you possibly send me your full training script?

lucidrains avatar Jun 26 '22 19:06 lucidrains

@xinmiaolin hmm, could you possibly send me your full training script?

sure, thanks!

xinmiaolin avatar Jun 26 '22 21:06 xinmiaolin

Did you check your optimizer? I think it may be because your parameters are not receiving gradients in the backward pass. I ran into this problem before: the loss decreased slowly, and I eventually found something wrong in optimizer.py.
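
A quick way to check is to verify every trainable parameter receives a gradient after backward (plain PyTorch sketch, not specific to this repo; decoder and images as in your script):

```python
# after one backward pass, every trainable parameter should have received a gradient
loss = decoder(images)
loss.backward()

for name, param in decoder.named_parameters():
    if param.requires_grad and param.grad is None:
        print(f'no gradient reached {name} -- this parameter is not being trained')
```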

YUHANG-Ma avatar Jul 06 '22 01:07 YUHANG-Ma

> Did you check your optimizer? I think it may be because your parameters are not receiving gradients in the backward pass. I ran into this problem before: the loss decreased slowly, and I eventually found something wrong in optimizer.py.

Hi thank you for the suggestion. I will check on it.

xinmiaolin avatar Jul 07 '22 00:07 xinmiaolin

@xinmiaolin Have you solved the problem, or have you managed to train the model yet? I have the same problem.

QinSY123 avatar Nov 04 '22 14:11 QinSY123

Facing the same problem, would love to know if you found out the issue! :)

cc: @xinmiaolin

thecooltechguy avatar May 19 '23 18:05 thecooltechguy

@xinmiaolin Same for me, did you find out the root cause of that behaviour?

canelle20214 avatar Jun 18 '23 14:06 canelle20214