Add custom ops for compatibility with PT Compile

Open ani300 opened this issue 6 months ago • 5 comments

This PR adds basic support for torch.compile() for the non-varlen variants of Flash Attention.

This essentially allows models that use flash_attn_qkvpacked_func, flash_attn_kvpacked_func, or flash_attn_func to compile without graph breaks.

I can add unit tests if it makes sense. I'll test it with our own training pipeline for performance measurements and post the numbers later.

This uses the new custom operators API in PyTorch 2.4. I can move to the older APIs if needed, or look into making both coexist.

ani300 avatar Aug 08 '24 20:08 ani300