FlashAttention for `TransformerDecoder`
I went through torchtune/modules/transformer.py but could not find anything about FlashAttention. Is it supported? Otherwise, it would be great to see it in the future. I'm using Llama 3.1 with very long contexts and I believe FlashAttention will save me significant memory.
Hey @NeuralFlux, our modules use PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which automatically selects the best available backend for the attention computation. In most cases, flash attention will be used if your hardware supports it and you're using bf16. If you pass in a non-causal mask, it will fall back to memory-efficient attention (as is the case for sample packing).
If you're using our configs/recipes and are not modifying the modules or using sample packing, then you should be using flash attention.
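If you want to verify this on your own setup, here's a rough sketch (assuming a recent PyTorch build with CUDA and the `torch.nn.attention` module; the tensor shapes are made up) that restricts SDPA to the flash attention backend, so it errors out instead of silently falling back if flash attention can't be used:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Dummy (batch, num_heads, seq_len, head_dim) tensors in bf16 on GPU
q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the flash attention backend; if these inputs are
# not supported by flash attention, this raises rather than silently
# dispatching to another backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 4096, 64])
```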
That answers my question, thank you. Btw, I'm relatively new to this and would like to know what exactly a non-causal mask is. Where can I find more info?
Happy to explain.
A causal mask is simply a way to ensure that each predicted token depends only on the tokens before it. In practice, this materializes as a lower-triangular matrix:
```
        │ good │ morning │ how │ are │ you
────────┼──────┼─────────┼─────┼─────┼─────
good    │  ■   │         │     │     │
morning │  ■   │    ■    │     │     │
how     │  ■   │    ■    │  ■  │     │
are     │  ■   │    ■    │  ■  │  ■  │
you     │  ■   │    ■    │  ■  │  ■  │  ■
```
The filled square indicates that the token on the left "attends" to the token on the top. So you can see that "morning" only attends to "good morning", while "you" can attend to all of "good morning how are you". Without the causal mask, the transformer decoder could "cheat" and look ahead at the rest of the sentence to predict the next token.
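As a concrete sketch, this is all a causal mask is in code (using the five-token example from the table above):

```python
import torch

tokens = ["good", "morning", "how", "are", "you"]
seq_len = len(tokens)

# Lower-triangular boolean matrix: True where the row token is
# allowed to attend to the column token.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask.int())
```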
This is the default in most cases, but there are situations where you'd need a different mask (anything that is not the causal mask I'm referring to as non-causal). For example, in sample packing, where you concatenate multiple unrelated sequences into a single sample, you only want each sequence to attend to itself. So you get a block-triangular mask instead (see here for a great comparison; a rough sketch of building one is below). Or in the case of multimodal models, the mask used for cross-attention between image and text tokens does not follow the causal pattern either (see here for a depiction).
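To make the sample-packing case concrete, here's a minimal sketch (the sequence lengths are hypothetical) of building a block-causal mask, where each packed sequence gets its own lower-triangular block and cannot attend across sequence boundaries:

```python
import torch

# Hypothetical packed sample: two unrelated sequences of lengths 3 and 2
seq_lens = [3, 2]
total_len = sum(seq_lens)

mask = torch.zeros(total_len, total_len, dtype=torch.bool)
start = 0
for n in seq_lens:
    # Each sequence attends causally only within its own block
    mask[start:start + n, start:start + n] = torch.tril(
        torch.ones(n, n, dtype=torch.bool)
    )
    start += n

print(mask.int())
```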
You can try this article for a deeper dive into the math behind this.
Hope this is helpful. Let me know if you have more questions.
Got it, thanks for the references too!