DeepSpeed
support batch size dimension in 2D sparse attention mask
The purpose of this PR is to enable different attention masks per mini-batch in the sparse attention module. Generally, sentences have different lengths, so it doesn't really make sense to assume that the 2D attention mask is constant across the mini-batch, as is currently the case.
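For illustration, here is a minimal sketch (not this PR's actual code; `length_to_attention_masks` is a hypothetical helper) of how a separate 2D mask per example could be built from the true sentence lengths:

```python
import torch

def length_to_attention_masks(lengths, seq_len):
    # lengths: [batch] tensor with the number of valid tokens per example.
    # Returns a boolean [batch, seq_len, seq_len] mask where entry (b, i, j)
    # is True only if positions i and j are both real tokens for example b.
    positions = torch.arange(seq_len, device=lengths.device)
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)  # [batch, seq_len]
    return valid.unsqueeze(2) & valid.unsqueeze(1)         # [batch, seq_len, seq_len]

# e.g. two sentences of length 5 and 9 in a mini-batch padded to 12 tokens
masks = length_to_attention_masks(torch.tensor([5, 9]), 12)
```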
Background: I am (ab-)using sparse attention as a cross-attention module by padding one of the inputs to the length of the other so that the attention matrix is square. Since the two inputs come from different sentences with different lengths, a non-symmetric 2D attention mask is the general case; however, this way I can still satisfy the built-in assumption of the block-sparse operations that the key-query matrix is square.
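A rough sketch of that padding trick, under the assumption that the key input is the shorter one (the function name and shapes are illustrative, not DeepSpeed's API):

```python
import torch
import torch.nn.functional as F

def pad_keys_to_query_length(q, k):
    # q: [batch, q_len, dim], k: [batch, k_len, dim] with k_len <= q_len.
    # Zero-pads the key sequence so that q @ k^T becomes a square
    # [q_len, q_len] matrix, and returns a boolean mask of the real key positions.
    batch, q_len, _ = q.shape
    k_len = k.shape[1]
    assert k_len <= q_len, "this sketch only pads the shorter (key) input"
    k_padded = F.pad(k, (0, 0, 0, q_len - k_len))           # pad along the seq dim
    key_mask = torch.zeros(batch, q_len, dtype=torch.bool, device=k.device)
    key_mask[:, :k_len] = True                               # True = real token
    return k_padded, key_mask
```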
I had to slightly generalize the input preparation, since the batch size is now an additional dimension that should not be squeezed.
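Conceptually, the input preparation then has to handle both the old shared 2D mask and a new batched 3D mask; a minimal sketch of that idea (hypothetical helper, not the patch itself):

```python
import torch

def prepare_attention_mask(attn_mask, batch_size, seq_len):
    # Accept either one shared [seq_len, seq_len] mask or a per-example
    # [batch, seq_len, seq_len] mask, and return a 3D tensor in both cases.
    if attn_mask.dim() == 2:
        # a single mask shared by the whole mini-batch: add a broadcast batch dim
        return attn_mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)
    if attn_mask.dim() == 3:
        # already batched: keep the batch dimension instead of squeezing it away
        assert attn_mask.shape[0] == batch_size
        return attn_mask
    raise ValueError(f"unexpected attention mask shape {tuple(attn_mask.shape)}")
```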
Can one of the admins verify this patch?
I do not need this functionality anymore, so I will only be able to provide limited guidance. Feel free to close if no one else needs this.
Thanks, @jglaser - closing this now then since it's fairly old. Appreciate your contributions, and next time we will do a much better job of reviewing things promptly.