Subtle bug in disentangled_attention_bias?
I think there may be a subtle bug in disentangled_attention_bias.
The HuggingFace implementation of this code is a more straightforward reproduction of Eqn (4) from the disentangled attention paper.
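For reference (writing this from memory, so please double-check against the paper), Eqn (4) decomposes the attention score between tokens i and j into content-to-content, content-to-position, and position-to-content terms:

$$
\tilde{A}_{i,j} = Q^c_i \, {K^c_j}^{\top} + Q^c_i \, {K^r_{\delta(i,j)}}^{\top} + K^c_j \, {Q^r_{\delta(j,i)}}^{\top}
$$

The detail I care about is that the third (p2c) term indexes the relative embedding with δ(j,i), not δ(i,j).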
The implementation here uses a computational trick: the embedding index tensor c2p_pos, which is computed for the content-to-position term, is reused in the block for position-to-content.
I'm worried about these lines:
```python
p2c_att = torch.bmm(pos_query_layer.to(key_layer)*scale, key_layer.transpose(-1, -2))
p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
```
To be clear: I understand why this looks backwards, compared to the eqn in the paper. Using dim=-2 rather than dim=-1 in the gather effectively takes the transpose of the matrix product. That's completely fine.
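To spell out what I mean, here is a tiny self-contained check of that gather identity (my own illustration, not code from this repo):

```python
import torch

# Gathering along dim=-2 is equivalent to: transpose, gather along dim=-1
# with the transposed index, then transpose back.
A = torch.arange(12.0).reshape(3, 4)
idx = torch.randint(0, 3, (3, 4))  # indices into dim -2 of A

via_dim_minus_2 = torch.gather(A, dim=-2, index=idx)
via_transpose = torch.gather(
    A.transpose(-1, -2), dim=-1, index=idx.transpose(-1, -2)
).transpose(-1, -2)

assert torch.equal(via_dim_minus_2, via_transpose)
```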
But why is it safe to reuse c2p_pos here, effectively indexing with δ(i,j) rather than δ(j,i)? Transposing the product doesn't turn the row index δ(i,j) into δ(j,i).
The HuggingFace implementation computes a separate embedding indexing tensor for p2c.
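To make the difference concrete, here is a small sketch (mine, not code from either repo) of the two index tensors, assuming relative positions are built as i − j and bucketed by clamping into [0, 2k), which is how I read the paper:

```python
import torch

k, seq_len = 3, 5
i = torch.arange(seq_len).unsqueeze(-1)  # query positions
j = torch.arange(seq_len).unsqueeze(0)   # key positions

delta_ij = torch.clamp(i - j + k, 0, 2 * k - 1)  # c2p-style index (under the i - j assumption above)
delta_ji = torch.clamp(j - i + k, 0, 2 * k - 1)  # what the p2c term appears to need

# The two index tensors only agree on the diagonal; for a square relative
# position matrix, delta_ji is the transpose of delta_ij, not delta_ij itself.
assert not torch.equal(delta_ij, delta_ji)
assert torch.equal(delta_ji, delta_ij.transpose(-1, -2))
```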
I imagine this is the kind of thing that could go unnoticed, because it should have a relatively minor effect on results.