
feat: added low VRAM flash attention backend

Open timmy-feng opened this issue 1 month ago • 7 comments

Motivation

Both of the existing attention backends have inefficiencies that limit training:

  • The sdpa backend materializes the full bsz x num_heads x q_len x kv_len attention score matrix in VRAM, severely limiting the maximum sequence length.
  • The flex_attention backend is very sensitive to the Linux environment and often requires different compilation flags depending on package versions. We were not able to get this kernel to compile reliably on torch==2.8.0.

Using a log-sum-exp (LSE) trick, we can avoid materializing any attention matrix while handling the TTT KV cache with minimal overhead. We implement this on top of the flash attention backend since it readily provides an LSE tensor alongside the O tensor. Flash Attention 4 is also SOTA for training on Blackwell; while porting FA4 is out of scope for this PR, supporting the flash attention interface is a first step.
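
For reference, merging two attention partials via their LSEs comes down to a few lines of PyTorch. This is a generic sketch of the log-sum-exp trick, not the code in this PR; tensor names and shapes are illustrative.

```python
import torch

def merge_attention_partials(o1, lse1, o2, lse2):
    """Combine two partial attention outputs computed over disjoint KV segments.

    o1, o2:     [bsz, num_heads, q_len, head_dim] partial attention outputs
    lse1, lse2: [bsz, num_heads, q_len] log-sum-exp of the raw attention scores
                over each segment (flash attention returns this alongside O)
    """
    # The combined normalizer is the log-sum-exp over both segments.
    lse = torch.logaddexp(lse1, lse2)
    # Each partial is reweighted by the share of probability mass its segment
    # contributes; no q_len x kv_len score matrix is ever materialized.
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```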

Modifications

Added a new LlamaFlashAttention module with the same API as LlamaAttention (using a manual hidden cache).

Within the forward pass, we:

  • Calculate the partial attention output with only the target model's KV cache using flash attention
  • Create singleton partial attention outputs for each of the successive TTT iterations
  • Combine all partials via a weighted sum using their LSEs (see the sketch after this list)
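
As a rough illustration of how these steps compose, the snippet below chains the pairwise merge from the sketch above across the target-cache partial and the per-step TTT partials. `flash_attn_partial` and the argument layout are placeholders, not the actual module internals.

```python
def combined_attention(q, target_kv, ttt_kv_steps):
    """Illustrative only: fold the per-TTT-step partials into the partial
    computed over the target model's KV cache."""
    # Partial over the target model's KV cache; flash attention gives us
    # both the output O and its log-sum-exp.
    o, lse = flash_attn_partial(q, *target_kv)  # placeholder call
    # One singleton partial per successive TTT iteration.
    for step_kv in ttt_kv_steps:
        o_step, lse_step = flash_attn_partial(q, *step_kv)
        o, lse = merge_attention_partials(o, lse, o_step, lse_step)
    return o
```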

Added a test file test_flash_attention.py that verifies equivalence with the SDPA backend (up to bf16 numerical precision).
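
An equivalence check of that kind typically boils down to running both modules on identical inputs and comparing under bf16 tolerances. The helper below is a hypothetical outline, not the contents of test_flash_attention.py:

```python
import torch

def assert_backends_match(attn_ref, attn_flash, hidden_states, **kwargs):
    """Hypothetical helper: run the SDPA-based and flash-based attention
    modules on the same inputs and require agreement up to bf16 noise."""
    out_ref = attn_ref(hidden_states, **kwargs)
    out_flash = attn_flash(hidden_states, **kwargs)
    torch.testing.assert_close(out_flash, out_ref, rtol=1.6e-2, atol=1e-2)
```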

Related Issues

Accuracy Test

Ran python -m tests.test_utils.test_flash_attention:

test_backward_pass_gradient_comparison (__main__.TestFlashAttention.test_backward_pass_gradient_comparison)
Test backward pass comparing gradients between LlamaAttention and LlamaFlashAttention. ... ok
test_forward_pass_comparison (__main__.TestFlashAttention.test_forward_pass_comparison)
Test forward pass comparison between LlamaAttention and LlamaFlashAttention. ... ok

----------------------------------------------------------------------
Ran 2 tests in 16.257s

OK

Benchmark & Profiling

Trained a speculator for GLM 4.5 on custom data on 8xH200, with a per-GPU batch size of 1 and a sequence length of 32K. Performance comparison with flex attention:

| Method | VRAM Usage | Speed (s/it) |
| --- | --- | --- |
| flex-attention | 888 GB | 9.5 |
| flash-attention | 854 GB | 7.2 |

We also trained for one epoch on perfectblend and achieved an accept length of 3 on GSM8K with a chain spec of 3 steps.

GLM 4.5 support was added in a custom branch built on top of this PR here.

Checklist

  • [x] Format your code according to the Code Formatting with Pre-Commit.
  • [x] Add unit tests as outlined in the Running Unit Tests.
  • [x] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [x] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [x] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [x] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.

timmy-feng avatar Nov 20 '25 06:11 timmy-feng

How much of a performance improvement does this offer compared to flex attention?

sleepcoo avatar Nov 20 '25 07:11 sleepcoo

I ran a comparison on 8xH200 and added it to the benchmark section. I saw a slight improvement over flex attention in VRAM usage and a larger one in speed (25% faster). We could probably push it further by supporting FA3 and FA4.

I was not able to get flex-attention to compile on B200, one of the core motivations for this feature.

timmy-feng avatar Nov 20 '25 22:11 timmy-feng

Thanks! I was not able to use flex-attention on B200 either. In the meantime, can you run pre-commit on your code?

FrankLeeeee avatar Nov 21 '25 06:11 FrankLeeeee

There is still a conflict with the main branch.

FrankLeeeee avatar Nov 23 '25 14:11 FrankLeeeee

I trained qwen2.5-vl-7B-eagle3 using the latest specforge 0.1.1 and sglang 0.5.5, and encountered "AttributeError: 'Qwen2_5_VLForConditionalGeneration' object has no attribute 'set_aux_hidden_states_layers'". I didn't have this issue when using the version before the fix. What could be the reason?

Abigbigbig avatar Nov 28 '25 06:11 Abigbigbig

@Abigbigbig This looks unrelated to this PR. Let's move it to a separate issue; I can point you to the fix.

yubofredwang avatar Nov 28 '25 07:11 yubofredwang