
Any tips for speeding up generation?

Open pabloppp opened this issue 3 years ago • 14 comments

Because of the autoregressive nature of Transformers, I know they are fairly slow when generating new sequences from scratch, but I was wondering if you have any tips or tricks for doing faster inference, or whether you have plans to add some of the tricks for avoiding the full computation, like the ones used by Huggingface: https://huggingface.co/blog/accelerated-inference

Thank you very much for your amazing work!

pabloppp avatar Mar 11 '21 09:03 pabloppp

@pabloppp Oh hey Pablo! Are you using the repository in production? Yeah, I can make the inference fast (by adding caching of key / values, standard practice)

lucidrains avatar Mar 12 '21 15:03 lucidrains
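A minimal sketch of what that "standard practice" key/value caching could look like for a single causal self-attention head; the function and tensor names here are illustrative, not part of x-transformers:

```python
import torch
import torch.nn.functional as F

def cached_attention_step(x_t, w_q, w_k, w_v, cache=None):
    """One causal self-attention step for a single head.

    x_t:   (batch, dim) hidden state of the newest token only.
    cache: dict with 'k' and 'v', each (batch, t, dim), holding the projections
           of all previous tokens, or None on the first step.
    Only the new token's q/k/v are computed; past keys/values are reused.
    """
    q, k_t, v_t = x_t @ w_q, x_t @ w_k, x_t @ w_v          # (batch, dim) each
    k_t, v_t = k_t[:, None], v_t[:, None]                  # (batch, 1, dim)
    if cache is not None:
        k = torch.cat([cache['k'], k_t], dim=1)            # (batch, t + 1, dim)
        v = torch.cat([cache['v'], v_t], dim=1)
    else:
        k, v = k_t, v_t
    scores = q[:, None] @ k.transpose(1, 2) / q.shape[-1] ** 0.5
    attn = F.softmax(scores, dim=-1)                        # (batch, 1, t + 1)
    out = (attn @ v)[:, 0]                                  # (batch, dim)
    return out, {'k': k, 'v': v}
```

No causal mask is needed because the only query is the newest position, so everything in the cache already lies in its past.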

That would be awesome. What I have tried in my custom implementations to speed up autoregressive self-attention is caching the output of the self-attention at timestep T and then, at timestep T+1, passing the full keys/values but only the last element of the query sequence, then getting the output and concatenating it with the cache. That way each new query can attend to the full previous sequence, but we don't need to recompute attention for all the previous queries when we only need the output at T+1. It looks something like this (see the sketch below):

[screenshot: Captura de pantalla 2021-03-12 at 17:01:32]

But I only achieved a x3 speedup 🤔
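A minimal sketch of the approach described above, in place of the screenshot. Unlike the cache sketched earlier, the keys and values are still recomputed from the full prefix at every step; the projection matrices w_q, w_k, w_v are again just illustrative:

```python
import torch
import torch.nn.functional as F

def last_query_attention(x, w_q, w_k, w_v, out_cache):
    """x:         (batch, t, dim) the full sequence generated so far.
    out_cache: (batch, t - 1, dim) attention outputs cached from earlier steps.

    Keys/values are recomputed for the whole sequence (which limits the
    speedup), but attention itself is evaluated only for the last query.
    """
    q = x[:, -1:] @ w_q                              # (batch, 1, dim), last query only
    k, v = x @ w_k, x @ w_v                          # (batch, t, dim), full keys/values
    attn = F.softmax(q @ k.transpose(1, 2) / q.shape[-1] ** 0.5, dim=-1)
    out_t = attn @ v                                 # (batch, 1, dim), output at T+1
    return torch.cat([out_cache, out_t], dim=1)      # append to the cached outputs
```

Because the key/value projections are still computed over the whole prefix at every step, the per-token cost remains linear in sequence length, which is consistent with the modest ~x3 speedup reported.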

I actually needed to perform autoregressive inference on a very large dataset, and it was taking more than a day even with the above speedup. I am currently doing some weird custom stuff, keeping the Transformer attention layers but replacing the self-attention layers with LSTMs, which are way faster at generating sequences token by token, and with that I achieve the x10 speedup that I needed.

pabloppp avatar Mar 12 '21 16:03 pabloppp

@pabloppp the biggest speedup you'll get is to train a vanilla transformer and then fine-tune it with Performer linear attention https://github.com/lucidrains/performer-pytorch - that's probably the ultimate trick

lucidrains avatar Mar 12 '21 16:03 lucidrains

What do you mean by 'fine-tune'? Training a vanilla transformer, then replacing the attention layers with Performer attention layers and doing some more training?

pabloppp avatar Mar 12 '21 16:03 pabloppp

Yes exactly!

lucidrains avatar Mar 12 '21 16:03 lucidrains
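For concreteness, a rough sketch of that swap-then-fine-tune workflow. The `blocks` / `self_attn` attribute names are hypothetical placeholders for however the vanilla model is organized; performer-pytorch's `SelfAttention` module is assumed as the drop-in causal linear-attention replacement:

```python
from performer_pytorch import SelfAttention  # FAVOR+ (Performer) attention module

def swap_in_performer_attention(model, dim=512, heads=8):
    """Replace each softmax self-attention module of a trained decoder with a
    causal Performer self-attention of the same width, then fine-tune.

    `model.blocks` and `.self_attn` are placeholder attribute names -- adapt
    them to the actual layout of your vanilla transformer.
    """
    for block in model.blocks:
        block.self_attn = SelfAttention(dim=dim, heads=heads, causal=True)
    return model

# After the swap, all non-attention weights (embeddings, feedforwards, norms)
# are kept from the vanilla checkpoint; training then continues for a while so
# the new attention layers can adapt.
```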

I will try that, thanks! Any idea of what the expected speedup could be?

pabloppp avatar Mar 12 '21 17:03 pabloppp

In short, it will be as fast as if you had an RNN

lucidrains avatar Mar 12 '21 17:03 lucidrains

https://arxiv.org/abs/2006.16236

lucidrains avatar Mar 12 '21 17:03 lucidrains
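For anyone wondering why linear attention generates like an RNN (this is the result of the paper linked above), here is a minimal single-head sketch of the recurrent formulation with the elu(x) + 1 feature map from Katharopoulos et al.; it carries a fixed-size state instead of a growing key/value cache, so each new token costs the same regardless of sequence length:

```python
import torch
import torch.nn.functional as F

def linear_attention_decode(qs, ks, vs, eps=1e-6):
    """Causal linear attention evaluated recurrently, one token at a time.

    qs, ks, vs: (seq_len, dim_head) projections for a single head.
    The state S (dim_head x dim_head) and normalizer z (dim_head,) summarize
    the entire past, so each step is O(1) in sequence length.
    """
    d = qs.shape[-1]
    S = torch.zeros(d, d)              # running sum of phi(k_i) v_i^T
    z = torch.zeros(d)                 # running sum of phi(k_i)
    outs = []
    for q, k, v in zip(qs, ks, vs):
        phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
        S = S + torch.outer(phi_k, v)
        z = z + phi_k
        outs.append((phi_q @ S) / (phi_q @ z + eps))
    return torch.stack(outs)
```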

@lucidrains thanks for your awesome work! Can you explain a bit why you don't recommend training Performers from scratch, and instead recommend training a vanilla transformer and then fine-tuning?

stas-sl avatar Mar 27 '21 10:03 stas-sl

@stas-sl Performers scale very efficiently at longer sequence lengths (roughly 1500+), but they lose that advantage for short sequences. This is especially true for the softmax Performer, which is the version that's directly compatible with vanilla Transformers: its constant costs for computing the attention can make it even slower than a regular Transformer during training. Hope that helps!

tomweingarten avatar Mar 28 '21 15:03 tomweingarten

@stas-sl What Tom said :)

@pabloppp relevant to your interests https://arxiv.org/abs/2103.13076

lucidrains avatar Mar 29 '21 19:03 lucidrains

Awesome, thanks!

pabloppp avatar Mar 29 '21 19:03 pabloppp

@lucidrains

> I can make the inference fast (by adding caching of key / values, standard practice)

Can you please explain how to make the inference fast by adding caching of keys/values?

DLExpert avatar Apr 09 '21 00:04 DLExpert

I am curious why the keys and values can be cached. Aren't the keys and values changed globally (except for the first decoder layer) after a new token id is produced?

cloudhan avatar May 23 '22 03:05 cloudhan