
Enable relative positional embedding in flash attention

Open · KumoLiu opened this issue 1 year ago · 1 comment

From reading this thread: https://github.com/pytorch/pytorch/issues/96099#issuecomment-1480430583, it seems to me that relative positional embedding can be integrated through the attn_mask argument of scaled_dot_product_attention. However, this can be slow, because passing such a mask means the call does not take the "fast path".
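A minimal sketch of the attn_mask route, assuming a simple 1D learned bias table indexed by the offset (i - j) between query i and key j. This is not MONAI's own relative positional embedding (which is decomposed per spatial axis); relative_position_bias, attention_with_rel_pos and reference_attention are hypothetical helpers for illustration. A float attn_mask is added to the scaled scores before the softmax, so the result should match a hand-written attention with the bias added explicitly.

```python
import torch
import torch.nn.functional as F


def relative_position_bias(bias_table: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Build an additive (heads, L, L) bias from a hypothetical learned table.

    bias_table has shape (heads, 2 * seq_len - 1); entry [h, (i - j) + seq_len - 1]
    is the bias that head h adds to the score between query i and key j.
    """
    pos = torch.arange(seq_len)
    rel = pos[:, None] - pos[None, :]        # (L, L), offsets in [-(L-1), L-1]
    return bias_table[:, rel + seq_len - 1]  # (heads, L, L)


def attention_with_rel_pos(q, k, v, bias_table):
    # q, k, v: (batch, heads, seq_len, head_dim)
    bias = relative_position_bias(bias_table, q.shape[-2])
    # The float mask is added to q @ k^T * scale before softmax. A non-trivial
    # mask can keep SDPA off its fastest kernel, which is the concern above.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias.unsqueeze(0))


def reference_attention(q, k, v, bias_table):
    # Explicit (unfused) attention with the same additive bias, for comparison.
    scale = q.shape[-1] ** -0.5
    scores = q @ k.transpose(-2, -1) * scale
    scores = scores + relative_position_bias(bias_table, q.shape[-2])
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
B, H, L, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
table = torch.randn(H, 2 * L - 1)
out = attention_with_rel_pos(q, k, v, table)
print(torch.allclose(out, reference_attention(q, k, v, table), atol=1e-5))
```

The bias is materialized as a full (heads, L, L) tensor here, which is exactly the memory cost a fused kernel with native bias support would avoid.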

Do you think we can keep this option open for users who want to use both flash_attention and rel_pos_embedding?

Originally posted by @mingxin-zheng in https://github.com/Project-MONAI/MONAI/pull/7977#discussion_r1701825032

KumoLiu, Aug 06 '24 15:08

I think https://github.com/Dao-AILab/flash-attention/pull/617 needs to be completed first for FAv2 (FlashAttention-2) to support an arbitrary attention bias. Then, depending on the relative encoding formula actually needed, https://github.com/Dao-AILab/flash-attention/pull/956 could perhaps be pushed forward.

Another way forward is to try PyTorch's flex_attention, which can fuse modifications of the attention score matrix into the attention kernel.

vadimkantorov, Aug 31 '24 09:08