
Investigate issues with nn_multihead_attention

Open · dfalbel opened this issue · 2 comments

See #496

dfalbel · Mar 22 '21 11:03

There still appear to be issues with this function that only surface when the q, k, v tensors are not all identical (a case the tests miss). For example:

```r
library(torch)

embed_dim <- 32L
num_heads <- 4L

multihead_attn <- nn_multihead_attention(embed_dim = embed_dim, num_heads = num_heads)

batch_size <- 8L
seq_len <- 5L
query <- torch_randn(seq_len, batch_size, embed_dim)
key <- torch_randn(seq_len, batch_size, embed_dim)
value <- torch_randn(seq_len, batch_size, embed_dim)

out <- multihead_attn(query, query, query)  # works
out <- multihead_attn(query, key, key)      # fails
out <- multihead_attn(query, key, value)    # fails
```

gives this output:

```
> out <- multihead_attn(query, query, query)
> out <- multihead_attn(query, key, key)
 Error in (function (self, size)  : 
  shape '[-1, 32, 8]' is invalid for input of size 1320 
> out <- multihead_attn(query, key, value)
 Error in (function (self, size)  : 
  shape '[-1, 32, 8]' is invalid for input of size 1320 
```
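A possible reading of the error numbers (my own arithmetic, not stated in the thread): the reshape to `[-1, 32, 8]` fails because 1320 is not divisible by 32 × 8. Notably, 1320 factors as 5 × 8 × 33, i.e. seq_len × batch_size × (embed_dim + 1), which would be consistent with an off-by-one when slicing the packed projection weights in the path taken when query and key differ. This is only a hypothesis about where the bug lives:

```python
# Sanity-check the shapes in the reported error.
# NOTE: the off-by-one interpretation below is an assumption, not confirmed
# by the thread itself.
seq_len, batch_size, embed_dim = 5, 8, 32

failing_size = 1320
correct_size = seq_len * batch_size * embed_dim  # what a correct projection yields

# The reshape to [-1, 32, 8] requires divisibility by 32 * 8 -- and 1320 isn't.
assert failing_size % (embed_dim * batch_size) != 0

# 1320 = 5 * 8 * 33 = seq_len * batch_size * (embed_dim + 1): one extra
# feature dimension, as an off-by-one weight slice would produce.
assert failing_size == seq_len * batch_size * (embed_dim + 1)
assert correct_size == 1280
```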

jonathanbratt · Apr 29 '21 18:04

I think I found the problem. Will submit PR shortly.

jonathanbratt · Apr 29 '21 19:04