
Memory Efficiency w.r.t Sequence Length

Open · adamoyoung opened this issue 4 years ago • 5 comments

I am a bit of a noob when it comes to transformers. If I want to encode a batch of N sequences of maximum length L, my understanding is that I do something like this:

from x_transformers import Encoder, TransformerWrapper

seqs = ['aba', 'cb', 'abcab']
N = len(seqs)
L = max(len(seq) for seq in seqs)
C = 3  # vocabulary size

padded_seqs = get_padded_seqs(seqs)  # N x L long tensor of token ids, padded out to length L
mask = get_seq_mask(seqs)            # N x L boolean tensor, True at real tokens
encoder = TransformerWrapper(
    num_tokens = C,
    max_seq_len = L,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)  # Encoder requires dim and depth
)
embeddings = encoder(padded_seqs, mask = mask, return_embeddings = True)  # N x L x dim
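
For concreteness, the two helpers might look something like this (just a sketch, assuming a character-level vocabulary; padded positions can stay 0 since the mask hides them):

import torch

vocab = {'a': 0, 'b': 1, 'c': 2}  # character-level vocabulary of size C = 3

def get_padded_seqs(seqs):
    L = max(len(s) for s in seqs)
    out = torch.zeros(len(seqs), L, dtype=torch.long)  # padded positions stay 0, the mask hides them
    for i, s in enumerate(seqs):
        out[i, :len(s)] = torch.tensor([vocab[ch] for ch in s])
    return out

def get_seq_mask(seqs):
    L = max(len(s) for s in seqs)
    lengths = torch.tensor([len(s) for s in seqs])
    return torch.arange(L)[None, :] < lengths[:, None]  # True at real tokens, False at padding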

In this transformer implementation, would there be a difference in memory usage if all of the sequences were of length L (i.e. all the mask values were True)?

adamoyoung · Oct 04 '21 14:10

My guess is there is no difference, based on how the masks are used in the Attention class
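
Something like this simplified version of masked attention is what I have in mind (not the library's actual code, just the general pattern), where the full score matrix is built regardless of the mask:

import torch

def masked_attention(q, k, v, mask):
    # q, k, v: (N, heads, L, dim_head); mask: (N, L) boolean, True at real tokens
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * q.shape[-1] ** -0.5   # (N, heads, L, L), always full size
    sim = sim.masked_fill(~mask[:, None, None, :], torch.finfo(sim.dtype).min)
    attn = sim.softmax(dim=-1)  # padded keys get ~zero weight, but the L x L matrix is still materialized
    return torch.einsum('bhij,bhjd->bhid', attn, v)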

adamoyoung · Oct 04 '21 14:10

@adamoyoung nope, no difference! you could strategically construct your batches to minimize padding tokens and maximize efficiency, but most practitioners never do so

lucidrains · Oct 04 '21 16:10

Thanks! Do you know if other implementations tend to do this as well? In pytorch_geometric they allow for graph batching where the memory usage scales with the number of nodes/edges actually in the batch, not the maximum number of nodes/edges that are allowed in a single graph (which is analogous to the sequence length). They do this by implementing the attention with scatter/gather operations instead of masked matrix multiplications. I'm wondering if this would be a good idea for transformers, and if you know of anyone who has tried this.
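
To make the idea concrete, here is a toy version of what I mean (a Python loop standing in for the scatter/gather kernels); memory then scales with the actual number of tokens per sequence rather than with N x L x L:

import torch

def unpadded_attention(q, k, v, lengths):
    # toy illustration only: q, k, v are (total_tokens, heads, dim_head), all sequences
    # concatenated with no padding; lengths is a list of per-sequence token counts
    # summing to total_tokens. a real kernel would vectorize this loop with scatter/gather ops.
    outs, start = [], 0
    scale = q.shape[-1] ** -0.5
    for n in lengths:
        qi, ki, vi = (t[start:start + n] for t in (q, k, v))   # (n, heads, dim_head)
        sim = torch.einsum('ihd,jhd->hij', qi, ki) * scale     # (heads, n, n): scales with n, not max length
        attn = sim.softmax(dim=-1)
        outs.append(torch.einsum('hij,jhd->ihd', attn, vi))
        start += n
    return torch.cat(outs)                                     # (total_tokens, heads, dim_head)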

adamoyoung · Oct 04 '21 17:10

@adamoyoung yea, the transformers community went in a very different direction than graph neural nets and how it is approached in PyG. we typically don't do it the scatter/gather way, though i have met researchers interested in writing CUDA kernels to remove attention on the padding. i think batching by similar lengths is a good middle ground that i've seen others use (one such implementation i came across: https://github.com/jonathanking/sidechainnet/blob/4d4f57204c162ab938b8762dfacffb1d992774d0/sidechainnet/dataloaders/SimilarLengthBatchSampler.py#L9 )
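
the gist of that kind of sampler is something like this (a stripped-down, hypothetical version, not the linked code): sort indices by length, chunk into batches, shuffle the batch order. you'd pass it to a DataLoader as batch_sampler=...

import random
from torch.utils.data import Sampler

class SortByLengthBatchSampler(Sampler):
    # hypothetical sketch: groups indices of similar-length sequences into batches
    def __init__(self, lengths, batch_size):
        self.lengths, self.batch_size = lengths, batch_size

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size] for i in range(0, len(order), self.batch_size)]
        random.shuffle(batches)  # shuffle batch order so training still sees varied data
        yield from batches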

lucidrains · Oct 04 '21 18:10

Thanks, that's a good solution! Will check it out.

adamoyoung · Oct 04 '21 18:10