Diff Attention out_proj dimension fix
There is a subtle potential issue: `self.out_proj = nn.Linear(embed_dim, embed_dim)` is applied to `attn.reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)`, which breaks in the rare case that `embed_dim != self.num_heads * 2 * self.head_dim`.
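For illustration, a minimal standalone sketch of how the mismatch can surface, assuming `head_dim` is derived with integer division as `embed_dim // num_heads // 2`; the concrete numbers are hypothetical:

```python
import torch
import torch.nn as nn

# Hypothetical dimensions chosen only to trigger the mismatch: the integer
# division drops a remainder when embed_dim is not divisible by 2 * num_heads.
embed_dim, num_heads = 100, 8
head_dim = embed_dim // num_heads // 2            # 6, so num_heads * 2 * head_dim = 96

out_proj = nn.Linear(embed_dim, embed_dim)        # expects 100 input features

bsz, tgt_len = 2, 16
attn = torch.randn(bsz, tgt_len, num_heads, 2 * head_dim)
attn = attn.reshape(bsz, tgt_len, num_heads * 2 * head_dim)  # last dim is 96, not 100

out = out_proj(attn)  # raises a RuntimeError due to the shape mismatch
```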
Small fix in the `__init__` of the base and flash attention implementations.
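A minimal sketch of what the adjusted `__init__` could look like, with only the attributes relevant to the fix shown; the class name and the `head_dim` derivation are assumptions for illustration:

```python
import torch.nn as nn

class MultiheadDiffAttnSketch(nn.Module):
    """Hedged sketch of the relevant projection layers, not the full module."""
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        # assumed derivation of head_dim via integer division
        self.head_dim = embed_dim // num_heads // 2
        # fix: out_proj input features match the reshaped attention output
        # (num_heads * 2 * head_dim) instead of assuming it equals embed_dim
        self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim)
```

With this change, `out_proj` accepts the reshaped tensor regardless of whether `embed_dim` equals `num_heads * 2 * head_dim`, while still projecting back to `embed_dim`.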
@microsoft-github-policy-service agree