Swapped naive dot product attention for flash attention

Open usryokousha opened this pull request 1 year ago • 4 comments

This pull request adds support for the Flash Attention mechanism to the MultiheadAttention module. Flash Attention is an IO-aware, exact algorithm for scaled dot-product attention that reduces memory usage and improves training efficiency compared to the naive implementation. The implementation in this pull request follows the paper "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (https://arxiv.org/abs/2205.14135).
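
For context, here is a minimal sketch (not the torchscale code) of what the swap amounts to: the naive path materializes the full attention weight matrix, while `torch.nn.functional.scaled_dot_product_attention` computes the same result through a fused kernel when one is available. All shapes below are arbitrary illustrations.

```python
import math
import torch
import torch.nn.functional as F

# Illustrative shapes only; names are not taken from torchscale.
q = torch.randn(2, 8, 16, 64)   # (batch, heads, tgt_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Naive attention: materializes the full (tgt_len, src_len) weight matrix.
weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
naive_out = weights @ v

# Fused path: same result, computed without materializing the weights when a
# flash / memory-efficient kernel is available for the given inputs.
fused_out = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(naive_out, fused_out, atol=1e-5))
```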

Changes Made:

  • Replaced the conventional multi-head attention computation with the Flash Attention mechanism in the forward() method.
  • Added support for the key_padding_mask argument when calling F.scaled_dot_product_attention().
  • Updated the incremental_state logic to work with the new tensor shapes introduced by Flash Attention.
  • Applied the xpos method to both the q and k tensors in the forward() method.
  • Replaced masked_fill with an additive mask that combines the attention mask and the key padding mask (see the sketch after this list).

Please review and merge.
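
A minimal sketch of the mask handling described in the last bullet, assuming q/k/v are already shaped (batch, heads, length, head_dim); the names and shapes are illustrative and not taken from the torchscale MultiheadAttention. Note that with an arbitrary attn_mask, PyTorch may fall back to a memory-efficient or math backend rather than the flash kernel.

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration.
bsz, num_heads, tgt_len, src_len, head_dim = 2, 8, 16, 16, 64

q = torch.randn(bsz, num_heads, tgt_len, head_dim)
k = torch.randn(bsz, num_heads, src_len, head_dim)
v = torch.randn(bsz, num_heads, src_len, head_dim)

# Boolean key padding mask: True marks padded source positions to be ignored.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True

# Causal attention mask expressed as an additive bias of shape (tgt_len, src_len).
attn_mask = torch.triu(torch.full((tgt_len, src_len), float("-inf")), diagonal=1)

# Fold the key padding mask into the same additive bias instead of calling
# masked_fill on the attention weights, then let it broadcast over heads/queries.
padding_bias = torch.zeros(bsz, 1, 1, src_len)
padding_bias = padding_bias.masked_fill(
    key_padding_mask.view(bsz, 1, 1, src_len), float("-inf")
)
additive_mask = attn_mask.view(1, 1, tgt_len, src_len) + padding_bias

# Fused scaled dot-product attention with the combined mask.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```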

usryokousha avatar Mar 31 '23 02:03 usryokousha

@microsoft-github-policy-service agree

usryokousha avatar Mar 31 '23 03:03 usryokousha

I ran into some issues using this branch as-is, and created a pull request for it here: https://github.com/usryokousha/torchscale/pull/1

Please review and pull in, if applicable.

mranzinger avatar Apr 27 '23 09:04 mranzinger

Please merge with master

usryokousha avatar Jun 20 '23 10:06 usryokousha