
[Cute] Block-sparse support for SM100

drisspg opened this issue 3 weeks ago · 0 comments

Summary

  • Implement block-sparse attention in flash_fwd_sm100.py
  • Update interface.py to handle SM100 block-size calculations (a 2x multiplier for m_block_size, since one CTA handles 2*tile_m rows; see the sketch after this list)
  • Add mask_mod parameter support in mask.py for block-sparse masking
  • Add SM100 test fixtures and tile size handling in test_mask_mod.py
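
As a rough illustration of the second and third bullets, a hedged sketch (not the PR's code): `sm100_m_block_size` is a hypothetical helper, and the `mask_mod(b, h, q_idx, kv_idx)` signature follows the FlexAttention convention.

```python
def causal_mask_mod(b, h, q_idx, kv_idx):
    # FlexAttention-style mask_mod: True where the (q_idx, kv_idx) pair may attend.
    return q_idx >= kv_idx


def sm100_m_block_size(tile_m: int) -> int:
    # Hypothetical helper: on SM100 one CTA covers two M tiles, so block-sparse
    # bookkeeping in interface.py uses an effective m_block_size of 2 * tile_m.
    return 2 * tile_m
```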

TODO before landing: properly divmod the aux tensors.
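
A minimal sketch of what that divmod could look like, assuming the aux tensors index KV blocks at a block-mask granularity that may differ from the kernel's tile size (all names below are hypothetical):

```python
def kernel_block_to_mask_block(n_block: int, tile_n: int, mask_block_n: int):
    # Map a kernel n-block to the block-mask entry that contains it, plus the
    # row offset inside that entry, via a single divmod.
    mask_block, offset = divmod(n_block * tile_n, mask_block_n)
    return mask_block, offset
```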

We should also land https://github.com/Dao-AILab/flash-attention/pull/1984 first and rebase on top of it, so this is easier to review.

Perf

A lot of perf wins (not universal, e.g. for the document mask), but the delta from SOL is much higher than what was found in the Hopper impl.

Not autotuning the flex block-sparse impl gives this: [combined_comparison]

And with autotuning the Triton impl: [combined_comparison_autotune]
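
For reference, a hedged sketch of how the Flex/Triton baseline is typically exercised with and without autotuning (assuming PyTorch's FlexAttention API and a CUDA device; the shapes and the causal mask are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 4, 16, 8192, 128
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
block_mask = create_block_mask(causal, B, H, S, S, device="cuda")

# Default compile vs. autotuned Triton lowering of the flex block-sparse impl.
flex_default = torch.compile(flex_attention)
flex_autotuned = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")

out_default = flex_default(q, k, v, block_mask=block_mask)
out_autotuned = flex_autotuned(q, k, v, block_mask=block_mask)
```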

Possible problems

Looking at the PM samples, we can see a long tail:

This is for causal_mask with the default StaticPersistentSchedule (we need to build a generic version of this), but we already have a better schedule for causal. If we hard-code the LPT schedule (a toy sketch of the idea follows below), we go from:

[Screenshot 2025-11-04 at 5.39.33 PM]

to:

[Screenshot 2025-11-04 at 5.39.14 PM]
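
A toy sketch of the LPT (longest-processing-time-first) idea, not the kernel's actual scheduler: hand the largest remaining block of work to the least-loaded persistent worker, which trims the long tail a static round-robin schedule leaves for causal masks.

```python
import heapq

def lpt_schedule(block_work, num_workers):
    # Greedy LPT: assign blocks in decreasing work order to the least-loaded worker.
    heap = [(0, w) for w in range(num_workers)]  # (current load, worker id)
    assignment = {w: [] for w in range(num_workers)}
    for block, work in sorted(enumerate(block_work), key=lambda x: -x[1]):
        load, w = heapq.heappop(heap)
        assignment[w].append(block)
        heapq.heappush(heap, (load + work, w))
    return assignment

# For a causal mask, m-block i touches roughly (i + 1) KV blocks of work.
num_m_blocks, num_sms = 64, 8
print(lpt_schedule([i + 1 for i in range(num_m_blocks)], num_sms))
```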

Tests

[Test result screenshots: 2025-11-10 at 8.26.43 PM, 8.37.51 PM, 8.43.15 PM]

drisspg · Nov 05 '25