
Memory efficient way to encode varying input sequence lengths in self attention

thoglu opened this issue 1 year ago · 12 comments

Hi, I have the issue that my sequences vary strongly in length, sometimes with outliers that are an order of magnitude longer than the average. In default PyTorch one can only pass a padding mask at the max length, and as far as I understand my current situation, that leads to rather large memory usage, because one effectively calculates attention at the max length. Is this the same in xformers? In "memory efficient self attention" one cannot pass a mask, so it seems that it only works for similar-length sequences.

My guess is that padding all sequences to "max_length" ensures one can efficiently parallelize the dot product on the GPU, but I wonder if there are any implementations that focus on memory efficiency and do not require padding, working instead similarly to "packed_sequences" for RNNs in PyTorch?
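For illustration, a minimal sketch of the padded-batch approach the question refers to; the sizes and the use of nn.MultiheadAttention are assumptions made for the example, not from the original post:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.randn(L, 64) for L in (13, 7, 420)]   # one outlier dominates the batch
x = pad_sequence(seqs, batch_first=True)            # [B, max(L_i), 64]

# True marks padded positions that attention should ignore
key_padding_mask = torch.ones(len(seqs), x.shape[1], dtype=torch.bool)
for i, s in enumerate(seqs):
    key_padding_mask[i, : s.shape[0]] = False

mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
out, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
# Attention is still computed over max(L_i) positions for every sequence,
# so memory scales with B * max(L_i)^2 even when most sequences are short.
```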

❓ Questions and Help

thoglu · Mar 05 '23

Hi,

Indeed, padding sequences is not efficient, for multiple reasons. What you should do instead is concatenate all of your sequences together, and when you do the attention, pass a special mask to memory_efficient_attention. There is an illustration in the documentation, along with an example: https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.BlockDiagonalMask
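A minimal sketch of that approach; the sequence lengths, head count, and head dimension below are illustrative assumptions, and the queries/keys/values would of course come from your model rather than torch.randn:

```python
import torch
from xformers.ops import fmha, memory_efficient_attention

seqlens = [13, 7, 420]        # per-sample sequence lengths, including one long outlier
num_heads, head_dim = 8, 64

# Concatenate all sequences along the token dimension: shape [1, sum(L_i), H, D]
q = torch.randn(1, sum(seqlens), num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# The block-diagonal mask records where each sequence starts and ends, so the kernel
# only computes attention within each sequence -- no padding tokens are involved.
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)

out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)  # [1, sum(L_i), H, D]
```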

danthe3rd · Mar 06 '23

Cool, I will check that out, thanks :) Maybe it would be helpful to link to this from the README; it is not straightforward to find in the documentation.

thoglu · Mar 06 '23

Hi @danthe3rd, may I ask what to do if I do want to pad sequences? I'm working on vision tasks, and the sequence lengths do not vary a lot. Currently, I'm creating the attn_bias tensor myself.

function2-llx · Jul 26 '23

@function2-llx can you describe in more detail what your bias looks like, ideally with an example? It's usually better if you can avoid padding, and in fact we don't support any form of padding during training.

danthe3rd · Jul 27 '23

@danthe3rd My use case is as follows:

Suppose there are three images in a batch with shapes $(h_1, w_1)$, $(h_2, w_2)$, and $(h_3, w_3)$. After applying the patch embedding of a vision transformer, they become sequences of lengths $L_1, L_2, L_3$. Let's say $L_1$ is the longest length in this batch; the other sequences will be padded to length $L_1$. The attention bias for the $i$-th image will be a matrix of size $L_1 \times L_1$: the top-left $L_i \times L_i$ elements are filled with 0, and the rest with $-\infty$.
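Concretely, a small sketch of the bias just described; the sequence lengths are illustrative:

```python
import torch

seqlens = [100, 96, 81]       # L_1, L_2, L_3 after patch embedding; L_1 is the longest
L_max = max(seqlens)

# One (L_max, L_max) additive bias per image: 0 inside the top-left L_i x L_i block,
# -inf everywhere else so padded positions receive zero attention weight.
attn_bias = torch.full((len(seqlens), L_max, L_max), float("-inf"))
for i, L in enumerate(seqlens):
    attn_bias[i, :L, :L] = 0.0
```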

And yes, you're right that it would be better to avoid padding. I think another alternative would be to put images of similar shape into the same batch.

function2-llx · Jul 28 '23

So if you want to use xFormers like this, you will need to concatenate the patches together, e.g. instead of a shape [B, max(L_i), D], something like [1, sum(L_i), D], with the information about batches carried in the bias; see: https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.BlockDiagonalMask
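For instance, a sketch under the illustrative sizes from above (using attn_bias.split to recover per-image outputs is my reading of the API, not something stated in this thread):

```python
import torch
from xformers.ops import fmha, memory_efficient_attention

# Per-image patch sequences after patch embedding (illustrative sizes): [L_i, H, D]
num_heads, head_dim = 8, 64
seqs = [torch.randn(L, num_heads, head_dim, device="cuda", dtype=torch.float16)
        for L in (100, 96, 81)]

# [1, sum(L_i), H, D]: the batch dimension is 1, the batch structure lives in the mask
x = torch.cat(seqs, dim=0).unsqueeze(0)
attn_bias = fmha.BlockDiagonalMask.from_seqlens([s.shape[0] for s in seqs])

out = memory_efficient_attention(x, x, x, attn_bias=attn_bias)

# Recover one output tensor per image, shaped [1, L_i, H, D]
outs = attn_bias.split(out)
```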

danthe3rd · Jul 28 '23

Thank you for your suggestion. Yes, I've checked BlockDiagonalMask. But my concern is: will it have a much higher computation cost compared to padding when the sequence lengths are close? E.g., if padding to $L_{\max}=\max_i L_i$, the computation cost is proportional to $BL_{\max}^2$, while with concatenation it is $\left(\sum_i L_i\right)^2 = B^2L_{\text{average}}^2$. When the lengths are close, $L_{\text{average}}$ will be close to $L_{\max}$ and the computation cost is about $B$ times higher.

function2-llx · Jul 28 '23

Hmm, that's not how it works. It's not an additive bias that gets added after the entire attention matrix is calculated. This specific bias type is recognized by the kernel, which only computes the necessary areas of the attention matrix, so this bias will always be faster than padding.
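To spell out the cost comparison (a back-of-the-envelope summary, not a statement from the thread): with the block-diagonal mask the kernel's work scales as

$$\sum_i L_i^2 \;\le\; B\,L_{\max}^2,$$

so it never exceeds the padded cost of roughly $B L_{\max}^2$; the $\left(\sum_i L_i\right)^2$ figure above would only apply if the full $\left(\sum_i L_i\right)\times\left(\sum_i L_i\right)$ attention matrix were actually computed.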

danthe3rd · Jul 28 '23

Wow, that's cool, thanks for the clarification!

function2-llx · Jul 28 '23

Hi @function2-llx, @danthe3rd, sorry to bother you. I would like to replace xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) with a general PyTorch implementation. Can you provide me with some guidance? Thank you.

tzayuan · Apr 02 '24

> a general implementation of pytorch

What do you mean?

danthe3rd · Apr 02 '24

Problem solved. Thanks a lot!

tzayuan · Apr 07 '24
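For readers with the same question, a hedged sketch of one way to emulate BlockDiagonalMask.from_seqlens([N] * B) in plain PyTorch (not the solution found above, which is not described in the thread): build a boolean block-diagonal mask and pass it to torch.nn.functional.scaled_dot_product_attention. Note that this materializes the full mask, so it matches the numerics but not the memory savings of the fused xFormers kernel.

```python
import torch
import torch.nn.functional as F

B, N, num_heads, head_dim = 4, 16, 8, 64   # illustrative sizes
total = B * N

# Boolean block-diagonal mask: True where attention is allowed (within each length-N block)
allowed = torch.block_diag(*[torch.ones(N, N, dtype=torch.bool) for _ in range(B)])

# Concatenated sequences laid out as [1, H, B*N, D] for scaled_dot_product_attention
q = torch.randn(1, num_heads, total, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)  # [1, H, B*N, D]
```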