torchscale
Swapped naive dot product attention for flash attention
This pull request adds support for the Flash Attention mechanism to the MultiheadAttention module. Flash Attention computes exact attention with an IO-aware algorithm, reducing memory usage and improving training efficiency compared to the naive implementation. The implementation in this pull request follows the paper "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (https://arxiv.org/abs/2205.14135).
Changes Made:
- Replaced the conventional multi-head attention computation with Flash Attention in the forward() method (see the first sketch below).
- Added support for the key_padding_mask argument when calling F.scaled_dot_product_attention().
- Updated the incremental_state logic to work with the new tensor shapes introduced by Flash Attention (see the second sketch below).
- Applied the xpos method to both the q and k tensors in the forward() method.
- Replaced masked_fill with an additive mask that combines the attention mask and the key padding mask.
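
A minimal sketch of the core change, assuming PyTorch 2.x's F.scaled_dot_product_attention and illustrative shape/variable names (the exact names in MultiheadAttention may differ):

```python
import torch
import torch.nn.functional as F

def flash_attention(q, k, v, attn_mask=None, key_padding_mask=None, dropout_p=0.0):
    """q, k, v: (bsz, num_heads, seq_len, head_dim), the layout expected by
    F.scaled_dot_product_attention. attn_mask, if given, is assumed to already
    be an additive float mask of shape (tgt_len, src_len)."""
    bsz = q.size(0)
    src_len = k.size(2)

    # Fold the boolean key_padding_mask (bsz, src_len; True = padding) into a
    # single additive mask, since SDPA takes one attn_mask argument.
    mask = attn_mask
    if key_padding_mask is not None:
        pad = torch.zeros(bsz, 1, 1, src_len, dtype=q.dtype, device=q.device)
        pad = pad.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
        mask = pad if mask is None else mask + pad  # broadcasts to (bsz, 1, tgt_len, src_len)

    # Fused flash / memory-efficient kernels are selected internally when available.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)
```

Note that, depending on the PyTorch version, passing a dense float mask may route scaled_dot_product_attention to the memory-efficient backend rather than the FlashAttention kernel proper; the mask-free causal path (is_causal=True) is the most likely to hit the flash kernel.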
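
The cache update and xpos application look roughly like this; the prev_key/prev_value keys, the 4-D cache layout, and the xpos call signature are assumptions for illustration, not necessarily what this branch uses:

```python
import torch

def apply_cache_and_xpos(q, k, v, xpos=None, incremental_state=None):
    """Sketch of the decoding path. Assumes the k/v cache uses the 4-D
    (bsz, num_heads, seq, head_dim) layout required by flash attention, and
    that `xpos` is callable as xpos(x, offset=..., downscale=...)."""
    offset = 0
    if incremental_state is not None:
        if "prev_key" in incremental_state:
            # Concatenate the new step onto the cache along the sequence dim (2),
            # replacing the old (bsz * num_heads, seq, head_dim) handling.
            k = torch.cat([incremental_state["prev_key"], k], dim=2)
            v = torch.cat([incremental_state["prev_value"], v], dim=2)
        incremental_state["prev_key"] = k
        incremental_state["prev_value"] = v
        # The query's rotation offset is the number of already-cached positions.
        offset = k.size(2) - q.size(2)

    if xpos is not None:
        # xpos is applied to both q and k with opposite scaling so that their
        # inner product depends only on the relative distance.
        k = xpos(k, offset=0, downscale=True)
        q = xpos(q, offset=offset, downscale=False)
    return q, k, v
```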
Please review and merge.
@microsoft-github-policy-service agree
I ran into some issues using this branch as is, so I created a pull request addressing them here: https://github.com/usryokousha/torchscale/pull/1
Please review and pull in, if applicable.
Please merge with master.