fairseq
Assertion compares dimension of key_padding_mask with query dimension in xformers MHA
❓ Questions and Help
What is your question?
I came across this assertion: https://github.com/facebookresearch/fairseq/blob/3f6ba43f07a6e9e2acf957fc24e57251a7a3f55c/fairseq/modules/multihead_attention.py#L385
It compares the sequence length dimension of the key padding mask with `tgt_len`, which is the sequence length dimension of the *query*. This check fails whenever the key and query sequence lengths differ (e.g. in cross-attention).
Shouldn't the check instead be `key_padding_mask.size(1) == key.size(0)`?
Code
```python
tgt_len, bsz, embed_dim = query.size()
if key_padding_mask is not None:
    assert key_padding_mask.size(0) == bsz
    assert key_padding_mask.size(1) == tgt_len
```
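A minimal sketch of the mismatch, using illustrative shapes only (no actual fairseq code; the dimension names follow the convention above, where tensors are `(seq_len, bsz, embed_dim)` and the mask is `(bsz, key_len)`):

```python
# Cross-attention: the decoder query attends over encoder keys,
# so the two sequence lengths generally differ.
tgt_len, src_len, bsz, embed_dim = 5, 7, 2, 512

query_shape = (tgt_len, bsz, embed_dim)   # (5, 2, 512)
key_shape = (src_len, bsz, embed_dim)     # (7, 2, 512)

# key_padding_mask marks padded *key* positions, so it is (bsz, src_len).
key_padding_mask_shape = (bsz, src_len)   # (2, 7)

# First assertion passes: batch dimensions match.
assert key_padding_mask_shape[0] == bsz

# The assertion as written compares the mask against tgt_len and fails here:
print(key_padding_mask_shape[1] == tgt_len)       # 7 == 5

# The suggested check compares it against the key's sequence length instead:
print(key_padding_mask_shape[1] == key_shape[0])  # 7 == 7
```

In self-attention `tgt_len == src_len`, so both checks agree there; the difference only shows up in cross-attention.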
One more (possibly dumb) question: I am training a transformer model with fairseq and want to use xformers. Is it enough to install the xformers library in my environment and start training, or do I need to pass any additional arguments?