Bug in SA for DDPM UNet?

Open FutureXiang opened this issue 3 years ago • 3 comments

https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/05632f9f8e0de4657c210a13954a81f9556fd1ed/labml_nn/diffusion/ddpm/unet.py#L188

According to my understanding of self-attention, shouldn't the softmax operation be applied along the j axis of the einsum?

https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/05632f9f8e0de4657c210a13954a81f9556fd1ed/labml_nn/diffusion/ddpm/unet.py#L190

So, I think the code should be attn = attn.softmax(dim=2). Please correct me if I am wrong.

However, the Attention module (with the bug?) seems to work somehow, at least on the CIFAR-10 dataset.
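
For reference, here is a minimal sketch of how I read the attention computation around the linked lines (the shapes, identifiers, and einsum pattern are my reading of unet.py, not a verbatim copy):

```python
import torch

# Minimal sketch of the self-attention step; identifiers are illustrative.
batch_size, seq, n_heads, d_k = 2, 16, 4, 32
q = torch.randn(batch_size, seq, n_heads, d_k)
k = torch.randn(batch_size, seq, n_heads, d_k)
v = torch.randn(batch_size, seq, n_heads, d_k)
scale = d_k ** -0.5

# attn[b, i, j, h] = <q_i, k_j> for head h: i indexes queries, j indexes keys
attn = torch.einsum('bihd,bjhd->bijh', q, k) * scale

# The softmax should normalize over the keys, i.e. the j axis (dim=2),
# so that each query's weights over all keys sum to 1.
attn = attn.softmax(dim=2)  # the current code uses dim=1 (the i axis)

# Weighted sum of values per query
res = torch.einsum('bijh,bjhd->bihd', attn, v)
```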

FutureXiang avatar Sep 07 '22 17:09 FutureXiang

Thanks, you are right! That's a typo and a big bug!

vpj avatar Sep 11 '22 12:09 vpj

Thank you for your response!

I compared the FID results on classifier-free guidance conditional CIFAR-10 with attn = attn.softmax(dim=1) and attn = attn.softmax(dim=2), following the settings in the original DDPM paper.

However, I observed no meaningful difference between the buggy and the corrected model (FID within ±0.05), and the corrected model even performed slightly worse than the buggy one when trained for more iterations (FID -0.02~0.3).

Do you have any idea why this happens?

FutureXiang avatar Sep 11 '22 14:09 FutureXiang

This is strange. I guess the wrong softmax still provides a non-linearity similar to the correct one, and gradient descent finds a way to use it. But I don't understand how the wrong softmax could be better than the correct one. I wonder how well no attention at all would perform. I will also try to run some tests.
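
To make that concrete, a small illustrative check (purely a sketch, not from the repository): with dim=1 the weights still form a normalized, non-negative distribution, just over the queries instead of the keys, so the layer still behaves like a soft weighting.

```python
import torch

# Illustrative only: show which axis each softmax normalizes over.
# attn has shape [batch, i (queries), j (keys), heads].
attn = torch.randn(1, 3, 5, 2)

correct = attn.softmax(dim=2)  # each query's weights over keys sum to 1
buggy = attn.softmax(dim=1)    # each key's weights over queries sum to 1

print(correct.sum(dim=2)[0, :, 0])  # all ones: normalized over keys
print(buggy.sum(dim=1)[0, :, 0])    # all ones: normalized over queries
```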

Btw, I pushed the fix https://github.com/labmlai/annotated_deep_learning_paper_implementations/commit/7d1550dd67ef903959a6367165e17c15b89da5c8

vpj avatar Sep 12 '22 03:09 vpj