
Assertion compares dimension of key_padding_mask with query dimension in xformers MHA

Open sarthakgarg opened this issue 1 year ago • 1 comment

❓ Questions and Help

What is your question?

I came across this assertion: https://github.com/facebookresearch/fairseq/blob/3f6ba43f07a6e9e2acf957fc24e57251a7a3f55c/fairseq/modules/multihead_attention.py#L385. It compares the sequence-length dimension of the key padding mask with `tgt_len`, which is the sequence-length dimension of the query. This check fails whenever the key and query have different sequence lengths (e.g. in cross-attention). Shouldn't the check here be `key_padding_mask.size(1) == key.size(0)`?

Code

        tgt_len, bsz, embed_dim = query.size()

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == tgt_len
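
To illustrate why the current assertion breaks for cross-attention, here is a minimal pure-Python sketch (no torch dependency; plain shape tuples stand in for tensor dimensions, and the function names `current_check` / `proposed_check` are hypothetical, not part of fairseq):

```python
# Sketch of the reported issue: the assertion compares the mask's
# sequence dimension against the QUERY length, which only holds for
# self-attention where tgt_len == src_len.

def current_check(query_shape, mask_shape):
    """Mirrors the assertion quoted above: mask must be (bsz, tgt_len)."""
    tgt_len, bsz, _ = query_shape
    return mask_shape == (bsz, tgt_len)

def proposed_check(key_shape, mask_shape):
    """Proposed fix: compare the mask against the KEY length instead."""
    src_len, bsz, _ = key_shape
    return mask_shape == (bsz, src_len)

# Cross-attention example: decoder query of length 5, encoder key of length 9.
query_shape = (5, 2, 16)   # (tgt_len, bsz, embed_dim)
key_shape   = (9, 2, 16)   # (src_len, bsz, embed_dim)
mask_shape  = (2, 9)       # (bsz, src_len) -- the usual key_padding_mask shape

print(current_check(query_shape, mask_shape))    # False: assertion would fire
print(proposed_check(key_shape, mask_shape))     # True: shapes are consistent
```

In self-attention the two checks coincide (`tgt_len == src_len`), which is presumably why the assertion passes in the common case and only surfaces for encoder-decoder attention.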

sarthakgarg avatar Apr 13 '23 17:04 sarthakgarg

Just a dumb question. I am training a transformer model using fairseq and want to use xformers. Is it enough if I install xformers library in my environment and start training, or do I need to pass any additional arguments?

VarunGumma avatar Apr 30 '23 02:04 VarunGumma