
mha.py array shapes

erlebach opened this issue 1 year ago

I wonder why the array shapes in mha.py are (C, B, D) rather than (B, C, D). I thought the convention was that the batch is the first dimension. Specifically, here are the first few lines of the forward method of the MultiHeadAttention class:

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.      <<<<<<<<

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        query at position `i` has access to key-value at position `j`.
        """

Thanks.
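
As an aside, here is a minimal sketch of a mask in the documented [seq_len, seq_len, batch_size] layout; the causal pattern and the sizes below are only illustrative and not code from the repo:

    import torch

    # Illustrative sizes only.
    seq_len, batch_size = 6, 2

    # mask[i, j, b] == True means that, in batch element b, the query at
    # position i may attend to the key-value at position j; a lower-triangular
    # pattern gives causal (autoregressive) attention.
    causal = torch.tril(torch.ones(seq_len, seq_len)).bool()          # [seq_len, seq_len]
    mask = causal.unsqueeze(-1).expand(seq_len, seq_len, batch_size)  # [seq_len, seq_len, batch_size]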

erlebach · Jul 13 '24

Same question here, about the original code in the MultiHeadAttention class in mha.py. Because of the following logic, the softmax seems to operate across the batch, which I don't understand. Help would be appreciated.

    # the definition of the softmax
    self.softmax = nn.Softmax(dim=1)
    # the usage of the softmax
    attn = self.softmax(scores)  # here scores have a shape of [seq_q, seq_k, heads, d_k]
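
To see what dim=1 normalizes over, here is a minimal sketch that assumes the scores are laid out sequence-first as [seq_q, seq_k, batch_size, heads]; the sizes are arbitrary and the snippet is not from the repo:

    import torch
    import torch.nn as nn

    # Arbitrary sizes for illustration.
    seq_q, seq_k, batch_size, heads = 5, 7, 2, 4

    # Assuming scores have shape [seq_q, seq_k, batch_size, heads],
    # softmax over dim=1 normalizes across the key positions,
    # not across the batch.
    scores = torch.randn(seq_q, seq_k, batch_size, heads)
    attn = nn.Softmax(dim=1)(scores)

    # The weights for each (query position, batch element, head)
    # sum to one over the key positions.
    print(attn.sum(dim=1))  # all ones, shape [seq_q, batch_size, heads]

Under that layout, dim=1 is the key-position dimension, so nothing is mixed across the batch.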

dingyue772 · Nov 14 '24

Our implementation puts the sequence dimension first. PyTorch's LSTM used that convention, and since our initial implementations used [seq_len, batch_size, d_model] we just continued with it. Batch-first ([batch_size, seq_len, d_model]) is more commonly used now and is faster, since implementations like FlashAttention iterate over the sequence dimension.
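
For comparison, a minimal sketch using PyTorch's built-in nn.MultiheadAttention (not this repo's class) that shows the two conventions side by side; sizes are arbitrary:

    import torch
    import torch.nn as nn

    batch_size, seq_len, d_model, heads = 2, 10, 512, 8
    x = torch.randn(batch_size, seq_len, d_model)  # batch-first data

    # The built-in module also defaults to sequence-first (batch_first=False),
    # so batch-first data is transposed on the way in and out.
    mha_seq_first = nn.MultiheadAttention(d_model, heads)
    out, _ = mha_seq_first(x.transpose(0, 1), x.transpose(0, 1), x.transpose(0, 1))
    out = out.transpose(0, 1)  # back to [batch_size, seq_len, d_model]

    # With batch_first=True the transposes go away.
    mha_batch_first = nn.MultiheadAttention(d_model, heads, batch_first=True)
    out2, _ = mha_batch_first(x, x, x)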

vpj · Jul 18 '25