Bug in self-attention (SA) for DDPM UNet?
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/05632f9f8e0de4657c210a13954a81f9556fd1ed/labml_nn/diffusion/ddpm/unet.py#L188
According to my understanding of self-attention, shouldn't the softmax be applied along the j axis of the einsum output?
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/05632f9f8e0de4657c210a13954a81f9556fd1ed/labml_nn/diffusion/ddpm/unet.py#L190
So, I think the code should be attn = attn.softmax(dim=2).
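For concreteness, here is a minimal sketch of how I read that block (the shapes and the einsum pattern below are my assumptions based on the linked file, not a verbatim copy):

```python
import torch

batch, seq, n_heads, d_k = 2, 16, 4, 32
q = torch.randn(batch, seq, n_heads, d_k)
k = torch.randn(batch, seq, n_heads, d_k)
v = torch.randn(batch, seq, n_heads, d_k)
scale = d_k ** -0.5

# attn has shape [batch, i, j, heads], where i indexes queries and j indexes keys
attn = torch.einsum('bihd,bjhd->bijh', q, k) * scale

# For each query i, the weights over keys j should sum to 1,
# so the softmax should run over dim=2 (the j axis), not dim=1
attn = attn.softmax(dim=2)

# Weighted sum of values over the key axis j
res = torch.einsum('bijh,bjhd->bihd', attn, v)
```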
Please correct me if I am wrong.
However, the attention module (with the bug?) still seems to work somehow, at least on the CIFAR-10 dataset.
Thanks, you are right! That's a typo and a big bug!
Thank you for your response!
I compared the FID results on classifier-free guidance conditional CIFAR-10 with attn = attn.softmax(dim=1) and attn = attn.softmax(dim=2), following the settings in the original DDPM paper.
However, I observed essentially no difference between the bugged and the corrected model (FID within ±0.05), and the corrected model even performed worse than the bugged one when trained for more iterations (FID -0.02~0.3).
Do you have any idea why this happens?
This is strange. I guess the wrong softmax still provides a non-linearity similar to the correct one, and gradient descent finds a way to use it. But I don't understand how the wrong softmax could become better than the correct one. I wonder how well no attention at all would perform. I will also try to run some tests.
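One way to see what the bugged version actually computes (a tiny sketch, using the same assumed attention shape as above): with dim=1 the normalization runs over the query index i, so for a fixed query the weights over keys no longer sum to 1; they sum to 1 across queries for each key instead. It is still a normalized non-linearity, just over the wrong axis.

```python
import torch

attn_scores = torch.randn(2, 16, 16, 4)  # [batch, i (queries), j (keys), heads]

wrong = attn_scores.softmax(dim=1)  # bugged: normalizes over queries i
right = attn_scores.softmax(dim=2)  # fixed: normalizes over keys j

print(wrong.sum(dim=2)[0, 0, 0])  # generally != 1 for a fixed query
print(right.sum(dim=2)[0, 0, 0])  # == 1: proper attention weights per query
```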
Btw, I pushed the fix https://github.com/labmlai/annotated_deep_learning_paper_implementations/commit/7d1550dd67ef903959a6367165e17c15b89da5c8