
Multi-head attention code has a big problem

Open · Y-H-Joe opened this issue 2 years ago · 2 comments

I only checked the PyTorch version:

```python
import torch
from torch import nn
from d2l import torch as d2l

# `transpose_qkv` and `transpose_output` are helper functions defined in the
# same d2l section (:numref:`sec_multihead-attention`).

class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Defined in :numref:`sec_multihead-attention`"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)  # should not be 'query_size'
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)  # should not be 'key_size'
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)  # should not be 'value_size'
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)  # here, the last dim of queries is num_hiddens!
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for
            # `num_heads` times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of `output_concat`:
        # (`batch_size`, no. of queries, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
```

When training, if you change `num_hiddens` from 32 to 64, you get `RuntimeError: mat1 dim 1 must match mat2 dim 0`. After debugging, I found that in the `forward` function of the `MultiHeadAttention` block, the shape of `queries` (and likewise `keys` and `values`) is (`batch_size`, no. of queries or key-value pairs, `num_hiddens`), so the last dimension is `num_hiddens`. But `self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)` expects inputs whose last dimension is `query_size`. So you always have to set `num_hiddens = query_size` for the code to run, which is obviously wrong.

My suggestion is to change `self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)` to `self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias)` (and likewise for `W_k` and `W_v`).
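A minimal reproduction of the mismatch, isolating only the query projection rather than running the full d2l training loop. The sizes 32 and 64 are assumptions chosen to mirror the numbers in the report. The exact error text depends on the PyTorch version (newer releases say `mat1 and mat2 shapes cannot be multiplied`), so the sketch only checks that a `RuntimeError` is raised.

```python
import torch
from torch import nn

# Illustrative sizes (assumed): query_size stays 32, num_hiddens changed to 64
batch_size, num_queries = 2, 4
query_size, num_hiddens = 32, 64

# Inputs shaped as the forward() comment describes: last dim is num_hiddens
queries = torch.randn(batch_size, num_queries, num_hiddens)

# Projection as written in the quoted __init__: in_features = query_size
W_q = nn.Linear(query_size, num_hiddens, bias=False)
try:
    W_q(queries)
    mismatch_raised = False
except RuntimeError as err:  # message wording varies across PyTorch versions
    mismatch_raised = True
    print("RuntimeError:", err)

# The suggested change: in_features = num_hiddens
W_q_fixed = nn.Linear(num_hiddens, num_hiddens, bias=False)
out = W_q_fixed(queries)
print(out.shape)  # torch.Size([2, 4, 64])

# Why the original example runs: there num_hiddens == query_size == 32
W_q_equal = nn.Linear(32, 32, bias=False)
out_equal = W_q_equal(torch.randn(batch_size, num_queries, 32))
print(out_equal.shape)  # torch.Size([2, 4, 32])
```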

But there may be another solution.
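One possible alternative, offered as a sketch under assumptions rather than the maintainers' fix: instead of hard-coding either `query_size` or `num_hiddens` as the input size, let PyTorch infer it with `nn.LazyLinear`, which takes only the output size and fixes `in_features` from the last dimension of the first input it sees. The feature sizes 20 and 32 below are arbitrary.

```python
import torch
from torch import nn

num_hiddens = 64  # assumed output size of each projection

# nn.LazyLinear is given only the output size; in_features is
# inferred on the first forward call.
W_q = nn.LazyLinear(num_hiddens, bias=False)
W_k = nn.LazyLinear(num_hiddens, bias=False)

queries = torch.randn(2, 4, 20)  # query feature size 20 (arbitrary)
keys = torch.randn(2, 6, 32)     # key feature size 32 (arbitrary)

q_proj = W_q(queries)
k_proj = W_k(keys)

print(q_proj.shape, k_proj.shape)          # torch.Size([2, 4, 64]) torch.Size([2, 6, 64])
print(W_q.weight.shape, W_k.weight.shape)  # torch.Size([64, 20]) torch.Size([64, 32])
```

Compared with `nn.Linear(num_hiddens, num_hiddens)`, this keeps working when queries and keys have feature sizes different from `num_hiddens`; the trade-off is that the weights do not exist until the first forward pass.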

If my understanding is wrong, please correct me.

d2l is wonderful for sure.

P.S. The approach of building one large single-head attention and then reshaping it into multi-head attention is not elegant; it would be much better if you guys could find another solution.
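To make the "large single head reshaped into multi-head" idea concrete, here is a sketch of the reshape that `transpose_qkv` performs: `(batch_size, n, num_hiddens)` becomes `(batch_size * num_heads, n, num_hiddens / num_heads)`. The function name `split_heads` and all sizes are hypothetical; this is a re-implementation for illustration, not the d2l helper itself. It checks that each head is exactly one slice of rows of the big projection's weight, i.e. equivalent to a separate, narrower linear layer per head.

```python
import torch
from torch import nn

def split_heads(X, num_heads):
    """Hypothetical re-implementation of the reshape described above.

    (batch_size, n, num_hiddens) -> (batch_size * num_heads, n, num_hiddens // num_heads)
    """
    batch_size, n, num_hiddens = X.shape
    X = X.reshape(batch_size, n, num_heads, num_hiddens // num_heads)
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, n, head_dim)
    return X.reshape(batch_size * num_heads, n, num_hiddens // num_heads)

torch.manual_seed(0)
batch_size, n, in_features, num_hiddens, num_heads = 2, 5, 16, 12, 3
head_dim = num_hiddens // num_heads

W = nn.Linear(in_features, num_hiddens, bias=False)  # one "large single head"
X = torch.randn(batch_size, n, in_features)
heads = split_heads(W(X), num_heads)
print(heads.shape)  # torch.Size([6, 5, 4])

# Head i of batch element b sits at index b * num_heads + i, the same order
# that torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0) produces.
b, i = 1, 2
rows = W.weight[i * head_dim:(i + 1) * head_dim]  # this head's slice of W
per_head = X[b] @ rows.T                          # a separate small projection
print(torch.allclose(heads[b * num_heads + i], per_head, atol=1e-5))  # True
```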

Y-H-Joe · Mar 03 '22 18:03

I found this problem too. `self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)` means H_in should be `query_size`, but the input tensor's shape is (`batch_size`, `num_queries`, `num_hiddens`), so the last dim is `num_hiddens`, not `query_size`. It works only because, in this example, we made `num_hiddens = query_size`. If you set `num_hiddens` to a value different from `query_size`, you will probably get an error like `mat1 and mat2 shapes cannot be multiplied`.

batman47steam · Apr 19 '22 11:04

Hi, sorry for not commenting earlier; we've been short on bandwidth while refactoring the complete attention chapter. I'll reply once we've looked into these issues. Feel free to open PRs if you'd like to take a shot at it yourself. Happy to help :))

Thanks for your patience.

AnirudhDagar · Apr 19 '22 12:04