feat: add low-VRAM flash attention backend
Motivation
The two existing attention backends both have limitations that hurt the training experience.
The `sdpa` backend materializes the full `bsz x num_heads x q_len x kv_len` attention score matrix in VRAM, severely limiting the maximum sequence length. The `flex_attention` backend is very particular about the Linux environment and often requires different compilation flags depending on package versions; we were not able to get this kernel to compile reliably on `torch==2.8.0`.
Using a log-sum-exp (LSE) trick, we can avoid materializing any attention score matrix while handling the TTT KV cache with minimal overhead. We support this via the flash attention backend, since it readily provides an LSE tensor alongside the output (O) tensor. FlashAttention 4 is also SOTA for training on Blackwell, and while porting FA4 is out of scope for this PR, supporting the flash attention interface is a first step.
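For reference, this is the standard flash-attention merge identity (stated here for clarity; it is not new to this PR). If $O_i$ and $L_i$ are the partial attention output and log-sum-exp computed over the $i$-th disjoint slice of the KV cache, then

$$
L = \log \sum_i e^{L_i}, \qquad O = \sum_i e^{L_i - L}\, O_i,
$$

which equals exact softmax attention over the union of the slices, so no full score matrix is ever materialized.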
Modifications
Added a new `LlamaFlashAttention` module which has the same API as `LlamaAttention` (using a manual hidden cache).
Within the forward pass, we:
- Calculate the partial attention output with only the target model's KV cache using flash attention
- Create singleton partial attention outputs for each of the successive TTT iterations
- Combine all partials via a weighted sum using their LSEs (sketched below)
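A minimal sketch of that combination step, assuming partial outputs of shape `[bsz, num_heads, q_len, head_dim]` and LSE tensors of shape `[bsz, num_heads, q_len]`. The function and tensor names here are illustrative, not the module's actual code:

```python
import torch

def merge_partial_attention(o_parts, lse_parts):
    """Merge attention outputs computed over disjoint KV slices.

    o_parts:   list of [bsz, num_heads, q_len, head_dim] partial outputs.
    lse_parts: list of [bsz, num_heads, q_len] log-sum-exp tensors returned
               alongside each partial output by the flash attention kernel.
    """
    o = torch.stack(o_parts, dim=0).float()      # [n, b, h, q, d]
    lse = torch.stack(lse_parts, dim=0).float()  # [n, b, h, q]

    # Each slice's weight is exp(lse_i - lse_global); the weights sum to 1
    # across slices, so the result equals exact softmax attention over the
    # union of the slices.
    lse_global = torch.logsumexp(lse, dim=0)            # [b, h, q]
    weights = torch.exp(lse - lse_global.unsqueeze(0))  # [n, b, h, q]
    return (weights.unsqueeze(-1) * o).sum(dim=0)       # [b, h, q, d]
```

Upcasting to fp32 for the merge keeps bf16 rounding error out of the weighted sum; the result can be cast back to the compute dtype afterwards.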
Added a test file `test_flash_attention.py` which verifies equivalence with the SDPA backend (up to bf16 numerical tolerance).
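The test compares the two attention modules end to end. As a self-contained illustration of the property being checked (not the test file's actual code, and with illustrative tolerances), the LSE merge can be verified against SDPA on raw q/k/v tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, q_len, d = 1, 8, 128, 64
q = torch.randn(b, h, q_len, d, dtype=torch.bfloat16)
k = torch.randn(b, h, 2 * q_len, d, dtype=torch.bfloat16)
v = torch.randn(b, h, 2 * q_len, d, dtype=torch.bfloat16)

# Reference: SDPA over the full KV in one shot.
ref = F.scaled_dot_product_attention(q, k, v)

def partial(q, k, v):
    # Naive attention over one KV slice, returning the output and its LSE.
    scores = (q.float() @ k.float().transpose(-1, -2)) / d ** 0.5
    return torch.softmax(scores, dim=-1) @ v.float(), torch.logsumexp(scores, dim=-1)

# Attend to each half of the KV separately, then merge via the LSEs.
o1, l1 = partial(q, k[:, :, :q_len], v[:, :, :q_len])
o2, l2 = partial(q, k[:, :, q_len:], v[:, :, q_len:])
lse = torch.logaddexp(l1, l2)
merged = (torch.exp(l1 - lse).unsqueeze(-1) * o1
          + torch.exp(l2 - lse).unsqueeze(-1) * o2).to(torch.bfloat16)

# Loose tolerances on purpose: bf16 only carries ~3 significant digits.
torch.testing.assert_close(merged, ref, rtol=2e-2, atol=2e-2)
```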
Related Issues
Accuracy Test
Ran `python -m tests.test_utils.test_flash_attention`:
test_backward_pass_gradient_comparison (__main__.TestFlashAttention.test_backward_pass_gradient_comparison)
Test backward pass comparing gradients between LlamaAttention and LlamaFlashAttention. ... ok
test_forward_pass_comparison (__main__.TestFlashAttention.test_forward_pass_comparison)
Test forward pass comparison between LlamaAttention and LlamaFlashAttention. ... ok
----------------------------------------------------------------------
Ran 2 tests in 16.257s
OK
Benchmark & Profiling
Trained a speculator for GLM 4.5 on custom data on 8xH200, with a per-GPU batch size of 1 and a sequence length of 32K. Performance comparison against flex attention:
| Method | VRAM Usage | Speed (s/it) |
|---|---|---|
| flex-attention | 888 GB | 9.5 |
| flash-attention | 854 GB | 7.2 |
We also trained for one epoch on perfectblend and achieved an accept length of 3 on GSM8K with a chain spec of 3 steps.
GLM 4.5 support was added in a custom branch built on top of this PR here.
Checklist
- [x] Format your code according to the Code Formatting with Pre-Commit.
- [x] Add unit tests as outlined in the Running Unit Tests.
- [x] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [x] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [x] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [x] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.
How much of a performance improvement does this offer compared to flex attention?
I ran a comparison on 8xH200 and added it to the benchmarks section. I saw an improvement over flex attention in both VRAM usage and speed (25% faster). We could probably push it further by supporting FA3 and FA4.
I was not able to get flex-attention to compile on B200, which was one of the core motivations for this feature.
Thanks! I was not able to use flex-attention on B200 either. Meanwhile, could you run pre-commit on your code?
There is still a conflict with the main branch.
I trained qwen2.5-vl-7B-eagle3 using the latest specforge 0.1.1 and sglang 0.5.5, and encountered `AttributeError: 'Qwen2_5_VLForConditionalGeneration' object has no attribute 'set_aux_hidden_states_layers'`. I didn't have this issue when using the version before the fix. What could be the reason?
@Abigbigbig This looks like a separate issue from this PR; let's move it to its own issue. I can point you to the fix.