mha.py array shapes
I wonder why the array shapes in mha.py are (C, B, D) rather than (B, C, D). I thought the convention was that the batch is the first dimension. Specifically, here are the first few lines of the forward method of class MultiHeadAttention:
def forward(self, *,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            mask: Optional[torch.Tensor] = None):
    """
    `query`, `key` and `value` are the tensors that store
    collection of *query*, *key* and *value* vectors.
    They have shape `[seq_len, batch_size, d_model]`. <<<<<<<<
    `mask` has shape `[seq_len, seq_len, batch_size]` and
    `mask[i, j, b]` indicates whether for batch `b`,
    query at position `i` has access to key-value at position `j`.
    """
Thanks.
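For concreteness, here is a rough sketch of tensors in the layout the docstring describes; the sizes are made up, only the shapes matter, and the causal mask is just one example of a valid mask:

import torch

seq_len, batch_size, d_model = 10, 32, 512

# Sequence-first layout: [seq_len, batch_size, d_model]
query = torch.randn(seq_len, batch_size, d_model)
key = torch.randn(seq_len, batch_size, d_model)
value = torch.randn(seq_len, batch_size, d_model)

# Causal mask of shape [seq_len, seq_len, batch_size]:
# mask[i, j, b] is True when query position i may attend to key position j in batch b
mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
mask = mask.unsqueeze(-1).expand(seq_len, seq_len, batch_size)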
Same question here, about the original code in class MultiHeadAttention in mha.py. Because of the following logic, it looks to me like the softmax operates across the batch, which I don't understand. Need help.
# the definition of softmax
self.softmax = nn.Softmax(dim=1)
# the usage of softmax
attn = self.softmax(scores)  # here scores have a shape of [seq_q, seq_k, heads, d_k]
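For what it's worth, a small sketch of what dim=1 normalizes over, assuming scores has shape [seq_q, seq_k, batch_size, heads], which is what an einsum of the form 'ibhd,jbhd->ijbh' over sequence-first query/key tensors would produce (the exact shape in mha.py may differ from the comment above):

import torch

seq_q, seq_k, batch_size, heads, d_k = 5, 7, 3, 4, 16

# Sequence-first query/key: [seq_len, batch_size, heads, d_k]
q = torch.randn(seq_q, batch_size, heads, d_k)
k = torch.randn(seq_k, batch_size, heads, d_k)

# Dot products between every query and key position: [seq_q, seq_k, batch_size, heads]
scores = torch.einsum('ibhd,jbhd->ijbh', q, k)

# Softmax over dim=1 normalizes across key positions, separately for each
# (query position, batch element, head) triple -- not across the batch
attn = torch.softmax(scores, dim=1)
print(attn.sum(dim=1))  # all ones: each set of attention weights sums to 1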
Our implementation puts the sequence dimension first. PyTorch's LSTM used that convention, and since our initial implementations used (C, B, D) we just continued with it. (B, C, D) is more commonly used now and is faster, since implementations like FlashAttention iterate over the sequence dimension.
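If your data is batch-first, adapting it is just a transpose; a minimal sketch of the round trip (standard PyTorch, with shape names following the docstring above):

import torch

batch_size, seq_len, d_model = 32, 10, 512

# Batch-first data, (B, C, D), as most other libraries use
x = torch.randn(batch_size, seq_len, d_model)

# Transpose to the sequence-first layout (C, B, D) that mha.py expects,
# then pass it as query/key/value for self-attention
x_seq_first = x.transpose(0, 1).contiguous()  # [seq_len, batch_size, d_model]

# The attention output has the same (C, B, D) layout, so transpose it back
# for downstream batch-first code (shown here on the input tensor)
out_batch_first = x_seq_first.transpose(0, 1)  # [batch_size, seq_len, d_model]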