[Cute] Block sparse support Sm100
Summary
- Implement block-sparse attention in flash_fwd_sm100.py
- Update interface.py to handle SM100 block size calculations (2x multiplier for m_block_size, since one CTA handles 2*tile_m rows; see the sketch after this list)
- Add mask_mod parameter support in mask.py for block-sparse masking
- Add SM100 test fixtures and tile size handling in test_mask_mod.py
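As a rough illustration of the 2x m_block_size multiplier and the flex-attention-style mask_mod hook (a minimal sketch; the helper and tile sizes below are hypothetical, not the actual interface.py code):

```python
# Sketch: SM100 block-size bookkeeping for block-sparse attention.
# Assumes a flex-attention-style mask_mod(b, h, q_idx, kv_idx) -> bool callable.

def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod convention: return True for positions that are kept.
    return q_idx >= kv_idx

def sm100_block_sizes(tile_m: int, tile_n: int) -> tuple[int, int]:
    # On SM100 one CTA covers 2 * tile_m query rows, so the block-sparse
    # m_block_size needs the 2x multiplier; the kv block size stays at tile_n.
    return 2 * tile_m, tile_n

# Example: with tile_m = 128, the block mask should be built at q block size 256
# so that one mask block corresponds to one CTA's worth of rows.
m_block_size, n_block_size = sm100_block_sizes(tile_m=128, tile_n=128)
assert (m_block_size, n_block_size) == (256, 128)
```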
TODO before landing: properly divmod the aux tensors
We should also land https://github.com/Dao-AILab/flash-attention/pull/1984 first and rebase on top of it, so this is easier to review.
Perf
A lot of perf wins (not universal for document masks), but the delta from SOL is much higher than what was found with the Hopper impl.
Without autotuning, the flex block-sparse impl gives this:
And with the Triton impl autotuned:
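(For reference, a sketch of how the flex baseline might be invoked, assuming the "flex blocksparse impl" means torch.nn.attention.flex_attention compiled with and without max-autotune; the shapes and mask are illustrative, not the benchmark config.)

```python
# Sketch of the flex-attention baseline, with and without Inductor autotuning.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 4, 16, 8192, 128
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
block_mask = create_block_mask(causal, B, H, S, S, device="cuda")

flex_default = torch.compile(flex_attention)                                   # no autotuning
flex_tuned = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")  # autotuned

out_default = flex_default(q, k, v, block_mask=block_mask)
out_tuned = flex_tuned(q, k, v, block_mask=block_mask)
```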
Possible problems
Looking at the PM samples, we can see a long tail:
This is for causal_mask with the default StaticPersistentSchedule (we still need to build a generic version of this), but we already have a better schedule for causal. If we hard-code the LPT schedule, we go from:
to:
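For context, a toy sketch of the LPT (longest-processing-time-first) idea for causal tiles; the names and cost model below are assumptions, not the kernel's actual scheduler:

```python
# Toy sketch of LPT ordering for causal attention tiles. Under a causal mask,
# later query blocks attend to more KV, so they carry more work; issuing them
# first to the persistent CTAs shortens the tail seen in the samples above.

def causal_tile_work(m_block: int, m_block_size: int) -> int:
    # Rough cost model: number of KV columns a query block touches under
    # causal masking (ignores partial blocks).
    return (m_block + 1) * m_block_size

def lpt_order(num_m_blocks: int, m_block_size: int) -> list[int]:
    # Heaviest tiles first, so the largest tile is never picked up last.
    return sorted(
        range(num_m_blocks),
        key=lambda m: causal_tile_work(m, m_block_size),
        reverse=True,
    )

# Example: with 8 query blocks, block 7 (the most KV work) is issued first.
print(lpt_order(num_m_blocks=8, m_block_size=128))  # -> [7, 6, 5, 4, 3, 2, 1, 0]
```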