x-transformers
Any tips for speeding up generation?
Because of the autoregressive nature of Transformers, I know they are fairly slow when generating new sequences from scratch. I was wondering if you had any tips or tricks for faster inference, or whether you have plans to add some of the tricks that avoid the full computation, like the ones used by Huggingface: https://huggingface.co/blog/accelerated-inference
Thank you very much for your amazing work!
@pabloppp Oh hey Pablo! Are you using the repository in production? Yeah, I can make inference fast by adding caching of keys/values, which is standard practice.
That would be awesome.
What I have tried in my custom implementations, to speed up autoregressive self-attention at inference, is caching the output of the self-attention at timestep T. Then, at timestep T+1, I pass the full keys/values but only the last element of the query sequence, get the output, and concatenate it with the cache. That way the newest query can attend to the full previous sequence, but we don't recompute attention for all the previous queries when we only need the output at T+1.
It looks something like this:
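For reference, here is a minimal single-head sketch of that trick (no projection weights, and the function name is made up for illustration; this is not the original snippet or the x-transformers code):

```python
import torch
import torch.nn.functional as F

def cached_self_attention(x, out_cache=None):
    # x: (batch, seq_len, dim) -- the full sequence generated so far
    # out_cache: previously computed outputs, (batch, seq_len - 1, dim), or None
    q = k = v = x  # single head, no projections, for illustration
    if out_cache is not None:
        q = q[:, -1:]  # only the newest timestep needs a query
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
    if out_cache is None:
        # the first, full-sequence pass still needs the causal mask;
        # the last query is allowed to see everything, so it needs none
        causal = torch.triu(
            torch.ones(scores.shape[-2], scores.shape[-1], dtype=torch.bool),
            diagonal=1,
        )
        scores = scores.masked_fill(causal, float('-inf'))
    out = F.softmax(scores, dim=-1) @ v
    if out_cache is not None:
        out = torch.cat((out_cache, out), dim=1)  # reuse the cached outputs
    return out
```

The incremental call still does a full-width `q @ k^T` for the last row, which is why the savings are bounded: only the query dimension shrinks, not the key/value dimension.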
But I only achieved a 3x speedup 🤔
I actually needed to perform autoregressive inference on a very large dataset, and it was taking more than a day even with the above speedup. I am currently doing some weird custom stuff: keeping the Transformer attention layers but replacing the self-attention layers with LSTMs, which are much faster at generating sequences token by token. With that I achieve the 10x speedup that I needed.
@pabloppp The fastest speedup you'll get is to train a vanilla transformer and then fine-tune it with Performer linear attention: https://github.com/lucidrains/performer-pytorch. That's probably the ultimate trick.
What do you mean by 'fine-tune'? Training a vanilla transformer, then replacing the attention layers with Performer attention layers and doing some more training?
Yes exactly!
I will try that, thanks! Any idea about what could be the expected speedup?
In short, it will be as fast as if you had an RNN
https://arxiv.org/abs/2006.16236
@lucidrains thanks for your awesome work! Can you explain a bit why you recommend training a vanilla transformer and then fine-tuning, rather than training a Performer from scratch?
@stas-sl Performers scale very efficiently at longer sequence lengths (roughly 1500+), but they lose that advantage for short sequences. This is especially true for the softmax Performer, which is the version that's directly compatible with vanilla Transformers. For the softmax Performer, the constant costs of calculating the attention can cause it to be even slower than a Transformer during training. Hope that helps!
@stas-sl What Tom said :)
@pabloppp relevant to your interests https://arxiv.org/abs/2103.13076
Awesome, thanks!
@lucidrains > I can make the inference fast (by adding caching of key / values, standard practice)
Could you please explain how to make inference fast by adding this caching of keys/values (the standard practice you mentioned)?
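For anyone looking for the general shape of the idea, here is a hedged single-head sketch of key/value caching (toy code with no projection weights; not the x-transformers implementation):

```python
import torch
import torch.nn.functional as F

def attend_with_cache(x_new, cache=None):
    # x_new: (batch, 1, dim) -- only the newly generated token's state
    # cache: dict with the keys/values of all previous tokens, or None
    if cache is None:
        k = v = x_new
    else:
        # append the new token's key/value instead of recomputing the past
        k = torch.cat((cache['k'], x_new), dim=1)
        v = torch.cat((cache['v'], x_new), dim=1)
    scores = x_new @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
    out = F.softmax(scores, dim=-1) @ v  # (batch, 1, dim)
    return out, {'k': k, 'v': v}

# usage: feed one token at a time, threading the cache through each step
x = torch.randn(2, 5, 8)
cache, outs = None, []
for t in range(x.shape[1]):
    out, cache = attend_with_cache(x[:, t:t + 1], cache)
    outs.append(out)
```

No causal mask is needed at all here, because each step's only query is the newest token, which is allowed to attend to everything before it.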
I am curious why the keys and values can be cached. Aren't the keys and values globally changed after a new token id is produced, except in the first decoder layer?
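For intuition on that question, here is a toy check (no projection weights, names made up): because of the causal mask, position i never attends to anything after i, so appending a new token leaves the hidden states of every earlier position unchanged in every layer of the stack, and the keys/values are just projections of those hidden states.

```python
import torch
import torch.nn.functional as F

def causal_layer(h):
    # toy self-attention layer with a causal mask (no projection weights)
    n = h.shape[1]
    scores = h @ h.transpose(-2, -1) / h.shape[-1] ** 0.5
    mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    return F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1) @ h

x = torch.randn(1, 6, 8)
h_prefix = causal_layer(causal_layer(x[:, :5]))  # two stacked layers, 5 tokens
h_full = causal_layer(causal_layer(x))           # same stack, one token appended
# the first 5 positions come out identical, so their keys/values never change
assert torch.allclose(h_full[:, :5], h_prefix, atol=1e-5)
```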