Add custom ops for compatibility with PT Compile
This PR adds basic support for torch.compile() for the non-varlen variants of Flash Attention.
This essentially allows models that use flash_attn_qkvpacked_func, flash_attn_kvpacked_func, and flash_attn_func to compile without graph breaks.
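For illustration (a minimal usage sketch, not part of the PR), a compiled module calling flash_attn_func can be checked for graph breaks by compiling with fullgraph=True, which raises an error if the graph breaks:

```python
import torch
from flash_attn import flash_attn_func

class Attention(torch.nn.Module):
    def forward(self, q, k, v):
        # flash_attn_func expects (batch, seqlen, nheads, headdim) fp16/bf16 CUDA tensors
        return flash_attn_func(q, k, v, causal=True)

model = torch.compile(Attention(), fullgraph=True)
q = k = v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
out = model(q, k, v)  # would fail under fullgraph=True if a graph break occurred
```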
I can add unit tests if that makes sense. I'll also run it through our own training pipeline to measure performance and post the results later.
This uses the new custom operators API in PyTorch 2.4. I can move to the older APIs if needed, or I can look into how to make both coexist.
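For reference, the PyTorch 2.4 API this builds on looks roughly like the sketch below; the op name and signature here are made up for illustration and are not the PR's actual wrappers:

```python
import torch

# Register a custom op so torch.compile treats it as an opaque call instead of breaking the graph.
@torch.library.custom_op("mylib::scaled_matmul", mutates_args=())
def scaled_matmul(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    return (a @ b) * scale

# Fake (meta) implementation so the compiler can infer output shapes without running the kernel.
@scaled_matmul.register_fake
def _(a, b, scale):
    return a.new_empty(a.shape[0], b.shape[1])
```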