
[Proposal] Optionally use flash attention.

Open tbenthompson opened this issue 1 year ago • 4 comments

It would be nice to have a flag to enable flash attention in models where that would make sense. This helps performance and memory usage in larger models. In my case, working with Pythia 12B, I get ~50% better performance and ~4x larger batch sizes when using flash attention. I also find numerical stability in float16 to be better with flash attention, probably because the model was trained using flash attention.

The downside of using flash attention in TransformerLens is that we would not have access to intermediate quantities in the attention calculation, such as the attention matrix itself. This is why I would suggest a default-off flag, so users can choose between flash attention and having those intermediate values available. In addition, when only a small subset of attention intermediates is needed, it's much faster to cache just the input to the attention layer (or the residual stream) and recompute those intermediates on demand.
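For illustration, here is a minimal sketch of what such a flag could look like, assuming PyTorch 2's `scaled_dot_product_attention` as the flash-attention entry point; the function signature and flag name are hypothetical, not TransformerLens's actual API:

```python
# Hypothetical sketch (not TransformerLens's real code): a config flag that
# switches between an eager attention path, where the pattern is materialised
# and can be cached, and PyTorch 2's fused scaled_dot_product_attention.
import torch
import torch.nn.functional as F

def attention(q, k, v, use_flash_attention: bool = False):
    # q, k, v: [batch, n_heads, seq_len, d_head]
    if use_flash_attention:
        # Fused kernel: faster and much more memory-efficient, but the
        # attention pattern is never materialised, so it cannot be cached.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Eager path: materialise scores and pattern so they can be hooked/cached.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    causal_mask = torch.triu(
        torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device),
        diagonal=1,
    )
    scores = scores.masked_fill(causal_mask, torch.finfo(scores.dtype).min)
    pattern = scores.softmax(dim=-1)
    return pattern @ v
```

The recompute idea above would then amount to caching only this function's inputs (or the residual stream) and re-running the eager path whenever a pattern is actually needed.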

Thanks!

tbenthompson avatar Sep 08 '23 17:09 tbenthompson

Seems reasonable to me; I'd be happy for someone to add this.

neelnanda-io avatar Sep 08 '23 17:09 neelnanda-io

Seems v. useful for sparse autoencoder training.

Docs here - https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#conclusion - in case anyone wants to take this (I'll pick it up at some point if no-one does).
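From that tutorial, a rough sketch of forcing the flash backend when calling `scaled_dot_product_attention`, so it errors out immediately if flash attention isn't usable for the given dtype/shape/hardware (the context manager below is the one the tutorial used for PyTorch 2.0/2.1; newer releases expose `torch.nn.attention.sdpa_kernel` instead):

```python
# Restrict SDPA to the flash backend so an unsupported setup fails loudly
# rather than silently falling back to the math or mem-efficient kernels.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```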

alan-cooney avatar Oct 28 '23 23:10 alan-cooney

I'd be quite keen to make a start on this soon. @alan-cooney, have you made a start already?

cmathw avatar Jan 24 '24 08:01 cmathw

I haven't yet so please feel free to!

alan-cooney avatar Jan 24 '24 19:01 alan-cooney