Investigate issues with nn_multihead_attention
See #496
Some issues with this function still appear when the q, k, v tensors are not all identical (so the tests miss them). For example:
library(torch)

embed_dim <- 32L
num_heads <- 4L
multihead_attn <- nn_multihead_attention(embed_dim = embed_dim, num_heads = num_heads)

batch_size <- 8L
seq_len <- 5L

# inputs in (seq_len, batch_size, embed_dim) layout
query <- torch_randn(seq_len, batch_size, embed_dim)
key <- torch_randn(seq_len, batch_size, embed_dim)
value <- torch_randn(seq_len, batch_size, embed_dim)

out <- multihead_attn(query, query, query)  # q == k == v: works
out <- multihead_attn(query, key, key)      # k == v, q differs: errors
out <- multihead_attn(query, key, value)    # all three distinct: errors
This gives:
> out <- multihead_attn(query, query, query)
> out <- multihead_attn(query, key, key)
Error in (function (self, size) :
shape '[-1, 32, 8]' is invalid for input of size 1320
> out <- multihead_attn(query, key, value)
Error in (function (self, size) :
shape '[-1, 32, 8]' is invalid for input of size 1320
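
For comparison, here is a minimal hand-rolled sketch of the same computation with separate q/k/v inputs, assuming the usual scaled dot-product attention semantics. The projection weights are fresh random stand-ins (not the module's parameters), so only the shapes are meaningful; it shows the output shape multihead_attn() should be able to return for these inputs:

head_dim <- embed_dim %/% num_heads

# stand-in projection weights, one per input (hypothetical, for shape checking only)
w_q <- torch_randn(embed_dim, embed_dim)
w_k <- torch_randn(embed_dim, embed_dim)
w_v <- torch_randn(embed_dim, embed_dim)

# project, then split into heads:
# (seq_len, batch, embed) -> (batch * heads, seq_len, head_dim)
split_heads <- function(x, w) {
  nnf_linear(x, w)$
    view(c(seq_len, batch_size * num_heads, head_dim))$
    transpose(1, 2)
}

q <- split_heads(query, w_q)
k <- split_heads(key, w_k)
v <- split_heads(value, w_v)

# scaled dot-product attention, batched over heads
attn <- nnf_softmax(torch_matmul(q, k$transpose(2, 3)) / sqrt(head_dim), dim = 3)
out <- torch_matmul(attn, v)$
  transpose(1, 2)$
  contiguous()$
  view(c(seq_len, batch_size, embed_dim))
out$shape  # 5 8 32 -- the shape multihead_attn() should return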
I think I found the problem. Will submit PR shortly.
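
For what it's worth, the failing size factors neatly: 1320 = seq_len * batch_size * (embed_dim + 1), i.e. the projected tensor carries exactly one extra embedding column per position. That pattern is consistent with an off-by-one when slicing the packed in-projection weight, since R ranges are inclusive on both ends. A hypothetical illustration (w_packed and the slices below are mine, not the package's actual code):

seq_len * batch_size * (embed_dim + 1L)  # 1320, matching the error message

w_packed <- torch_randn(3L * embed_dim, embed_dim)  # packed q/k/v in-projection weight
w_packed[embed_dim:(2L * embed_dim), ]$shape        # 33 32: a literal PyTorch-style slice keeps one row too many
w_packed[(embed_dim + 1L):(2L * embed_dim), ]$shape # 32 32: the intended slice with 1-based, inclusive indexing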