Suboptimal default initialization of q/k/v projections in `nn.MultiHeadDotProductAttention`
The q/k/v projections in `linen.MultiHeadDotProductAttention` are not initialized with a forward-backward normalized scheme (i.e. Xavier/Glorot). This does not cause optimization issues when a pre-LN transformer variant is used, but it does cause convergence issues in the vanilla post-LN variant from "Attention is All You Need".
System information
- Flax version: `flax==0.8.4`
Problem you have encountered:
Most papers over the past 4 years still use the vanilla post-LN transformer. One such example is facebookresearch/detr. Inputs to the self-attention block of the first decoder layer are all zeros, as shown below:
```python
# src: https://github.com/facebookresearch/detr/blob/29901c51d7fe8712168b8d0d64351170bc0f83e0/models/transformer.py#L55
...
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                  pos=pos_embed, query_pos=query_embed)
...
```
The default `linen.MultiHeadDotProductAttention` does not converge at all for the first 10 epochs. The cause is the initialization and resulting gradient-norm behavior of the q/k/v matrices, particularly when inputs to the MHDPA block are all zeros. I have a diff here with which the model converges from the very first epoch.
Please note: I have isolated the problem to the kernel initializer alone, which is what this proposal is about. My custom MHDPA implementation has no bearing on this proposal.
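To make the setup concrete, here is a minimal sketch of the all-zero-input case against the default layer. The shapes are placeholders I picked for illustration, not the real DETR configuration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Self-attention over all-zero inputs, mimicking how the first decoder
# layer sees its object queries in DETR (positional embeddings aside).
attn = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=256)

tgt = jnp.zeros((1, 100, 256))                  # object queries start at zero
params = attn.init(jax.random.PRNGKey(0), tgt)  # default lecun_normal kernels
out = attn.apply(params, tgt)                   # self-attention: q = k = v = tgt
```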
This convergence issue does not exist in PyTorch, because the q/k/v projections are `xavier_uniform`-initialized, with appropriate `fan_in` and `fan_out` values.
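For reference, this is roughly how PyTorch arrives at those fan values, as far as I can tell: `nn.MultiheadAttention._reset_parameters` applies `xavier_uniform_` to the fused `(3 * embed_dim, embed_dim)` in-projection matrix. The explicit call below is redundant (the constructor already does it) and is only there to make the scheme visible:

```python
from torch import nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=8)
# The fused q/k/v weight has shape (3 * embed_dim, embed_dim), so the
# xavier_uniform bound reflects the combined projection, not a single one.
nn.init.xavier_uniform_(mha.in_proj_weight)
print(mha.in_proj_weight.shape)  # torch.Size([768, 256])
```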
What you expected to happen:
Expected `linen.MultiHeadDotProductAttention` to converge from the first epoch, as is the case in facebookresearch/detr.
Fix:
- Switch to the `xavier_uniform` initializer for the projections (versus `default_kernel_init`, which is `lecun_normal`). This is also standard best practice (t-fixup paper).
- Use the correct `fan_in` value for the initializer: for the same embedding dimensions of q/k/v, `fan_in` should be `3 * embed_dim` (ref1, ref2). A rough sketch of such an initializer is shown below.
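A sketch of what such an override could look like through the public `kernel_init` argument. The fan accounting (scaling `fan_in` by 3) is my proposal; the exact helper and hardcoded shapes are illustrative, not a final API:

```python
import math
import jax
import jax.numpy as jnp
import flax.linen as nn

def qkv_xavier_uniform(key, shape, dtype=jnp.float32):
    # Assumes the q/k/v DenseGeneral kernel shape (embed_dim, num_heads, head_dim).
    # fan_in is scaled by 3 to emulate PyTorch's fused (3*embed_dim, embed_dim)
    # in-projection matrix; the output projection would need separate handling.
    fan_in = 3 * shape[0]
    fan_out = math.prod(shape[1:])
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, shape, dtype, minval=-bound, maxval=bound)

attn = nn.MultiHeadDotProductAttention(
    num_heads=8,
    qkv_features=256,
    kernel_init=qkv_xavier_uniform,  # overrides default_kernel_init (lecun_normal)
)
```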
Here is an approximate diff of the change needed in Flax:
https://github.com/MasterSkepticista/detr/commit/995f335237b72cf17fb1e187b8cb6faf5d51e784
There could be multiple ways to compute the correct `fan_in`, or one could use a single large dense layer for q/k/v as PyTorch does; the latter has performance implications.
Let me know if the fix in my code (barring the hardcoded values) is a reasonable approach.
Happy to do a PR.
Hey, thanks for looking into this! I think your analysis is correct, and I would favor using the xavier_uniform initializer with the scaled fan_in to emulate a bigger matmul. However, I worked on #3893 internally for a bit and learned that changing the initialization logic is very hard, as it breaks tons of tests that rely on the current defaults, so merging this is a challenge given finite resources.
In practice, we encourage users to fork our layers and adapt them to their needs, which is why this hasn't manifested as a strong issue. This is still worth thinking about, though; I'll try to discuss it internally.
Hi @cgarciae. Thanks for responding.
Finding this bug cost me three weeks of dry spells while replicating the PyTorch version. I think, at a minimum, a mention of this initialization scheme in the MHDPA docstrings would go a long way :)
On the potential test failures: per my understanding, this change would be localized to MultiHeadDotProductAttention, wouldn't it? What is the nature of those test failures? Are they training baselines? Just curious.