x-transformers
Memory Efficiency w.r.t. Sequence Length
I am a bit of a noob when it comes to transformers. If I want to encode a batch of N sequences of maximum length L, my understanding is that I do something like this:
from x_transformers import Encoder, TransformerWrapper
seqs = ['aba','cb','abcab']
N = len(seqs)
L = max(len(seq) for seq in seqs)
C = 3
padded_seqs = get_padded_seqs(seqs) # N x L long tensor
mask = get_seq_mask(seqs) # N x L boolean tensor
encoder = TransformerWrapper(num_tokens=C, max_seq_len=L, attn_layers=Encoder(dim=512, depth=6, heads=8))
embeddings = encoder(padded_seqs, mask=mask, return_embeddings=True) # N x L x dim float tensor
In this transformer implementation, would there be a difference in memory usage if all of the sequences were of length L (i.e. all the mask values were True)?
My guess is that there is no difference, based on how the masks are used in the Attention class.
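For my own sanity check, here is a minimal sketch (my guess at the general pattern, not the library's actual code) of how a boolean padding mask is usually applied inside attention: the full L x L score matrix is materialized first and the masked positions are simply filled with -inf before the softmax, so the allocation is the same whether the mask is all True or mostly padding.

import torch

def masked_attention(q, k, v, mask):
    # q, k, v: (N, L, D) float tensors; mask: (N, L) bool, True = real token
    scores = torch.einsum('n i d, n j d -> n i j', q, k) / q.shape[-1] ** 0.5  # full N x L x L matrix
    scores = scores.masked_fill(~mask[:, None, :], float('-inf'))  # padded keys are ignored, but the memory is already spent
    attn = scores.softmax(dim=-1)
    return attn @ v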
@adamoyoung nope, no difference! you could strategically construct your batches to minimize padding tokens and maximize efficiency, but most practitioners never do so
Thanks! Do you know if other implementations tend to do this as well? In pytorch_geometric they allow for graph batching where the memory usage scales with the number of nodes/edges actually in the batch, not the maximum number of nodes/edges that are allowed in a single graph (which is analogous to the sequence length). They do this by implementing the attention with scatter/gather operations instead of masked matrix multiplications. I'm wondering if this would be a good idea for transformers, and if you know of anyone who has tried this.
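To make the difference concrete, here is a rough back-of-the-envelope comparison with made-up lengths: padded attention allocates score matrices proportional to N * L^2, while a scatter-style approach in the spirit of PyG would scale with the sum of the squared true lengths.

lengths = [120, 35, 512, 80]  # hypothetical true sequence lengths
L = max(lengths)  # padded length
padded_entries = len(lengths) * L ** 2  # N * L^2 attention entries with padding
packed_entries = sum(l ** 2 for l in lengths)  # entries if only real tokens attended to each other
print(padded_entries, packed_entries)  # 1048576 vs 284169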
@adamoyoung yea, the transformers community went in a very different direction from graph neural nets and how it is approached in PyG. we typically don't do it the scatter/gather way, though i have met researchers who were interested in writing CUDA kernels to remove attention on the padding. i think batching by similar lengths is a good middle ground that i've seen used by others (one such implementation i came across: https://github.com/jonathanking/sidechainnet/blob/4d4f57204c162ab938b8762dfacffb1d992774d0/sidechainnet/dataloaders/SimilarLengthBatchSampler.py#L9 )
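roughly, the idea looks something like this (a simplified toy sketch of my own, not the linked SimilarLengthBatchSampler): sort indices by length and cut them into batches, so each batch only pads up to its own longest member.

import random

def length_bucketed_batches(lengths, batch_size, shuffle=True):
    # sort sample indices by sequence length so each batch contains similar lengths
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        random.shuffle(batches)  # shuffle the order of batches, not the contents of each batch
    return batches

# e.g. with the lengths of the sequences above, pad each batch only to its own max length
batches = length_bucketed_batches([3, 2, 5], batch_size=2)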
Thanks, that's a good solution! Will check it out.