Enable FlashInfer support for encoder models and add head_dim padding workaround
Motivation
This PR aims to enhance the FlashInfer attention backend in SGLang to address two primary goals:
- Enable support for encoder-only models: Currently, the FlashInfer backend needs adjustments to correctly handle non-causal attention required by encoder architectures.
- Resolve an "Invalid configuration" error for specific head dimensions: When using encoder models with certain head dimensions (e.g., `head_dim=32`, as found in `Supabase/gte-small` and potentially other BGE-like models) with FlashInfer's ragged prefill operations, an internal error is triggered, preventing these models from running.
The original issue is: #6050
Modifications
This PR introduces the following key changes:
1. Encoder Model Support (Non-Causal Attention):
   - In `FlashInferAttnBackend.forward_extend`, the `causal` flag is now dynamically determined. For layers with `layer.attn_type == AttentionType.ENCODER_ONLY`, `causal` is set to `False` to enable bidirectional (non-causal) attention.
   - For encoder self-attention, `save_kv_cache` is also set to `False`, as KV caching across layers is typically not used in the same way as in decoders. (A minimal sketch of this flag logic follows below.)
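
   As a rough illustration (not the exact diff), the flag logic amounts to something like the sketch below. `AttentionType` here is a placeholder enum standing in for SGLang's real attention-type enum, and `resolve_extend_flags` is a hypothetical helper used only for this sketch:

   ```python
   from enum import Enum, auto

   class AttentionType(Enum):
       # Placeholder for SGLang's real AttentionType enum.
       DECODER = auto()
       ENCODER_ONLY = auto()

   def resolve_extend_flags(attn_type: AttentionType, save_kv_cache: bool = True):
       """Return (causal, save_kv_cache) for forward_extend, per layer."""
       if attn_type == AttentionType.ENCODER_ONLY:
           # Encoder self-attention is bidirectional: disable the causal mask
           # and skip the KV-cache write for this layer.
           return False, False
       return True, save_kv_cache

   # Encoder-only layers run non-causal attention without caching KV.
   assert resolve_extend_flags(AttentionType.ENCODER_ONLY) == (False, False)
   assert resolve_extend_flags(AttentionType.DECODER) == (True, True)
   ```
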
2. Workaround for FlashInfer head_dim Limitation (e.g., head_dim=32):
   FlashInfer currently fails when using `BatchPrefillWithRaggedKVCacheWrapper` with `head_dim < 64` (e.g., 32). To work around this, we pad the head dimension up to 64 during the prefill and forward steps (see the sketch after this list):
   - A global variable `global_fake_head_dim` (default: 64) controls the padded size.
   - During prefill:
     - If the model's `head_dim` is less than `global_fake_head_dim`, we use the padded `fake_head_dim` for planning (`begin_forward`), but keep `sm_scale` based on the original `head_dim` for correctness.
   - During forward:
     - Q, K, and V tensors are padded along the head dimension.
     - `sm_scale` remains based on the original `head_dim`.
     - FlashInfer returns output with the padded size, which we truncate back to the original shape.

   This workaround is temporary until native support for `head_dim < 64` is available in FlashInfer.
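
   A minimal sketch of the padding idea, assuming a padded size of 64 and treating the already-planned FlashInfer ragged-prefill call as an opaque callable; `FAKE_HEAD_DIM`, `pad_head_dim`, and `prefill_with_padding` are illustrative names, not the PR's exact identifiers:

   ```python
   import math

   import torch
   import torch.nn.functional as F

   FAKE_HEAD_DIM = 64  # plays the role of global_fake_head_dim

   def pad_head_dim(x: torch.Tensor, target: int = FAKE_HEAD_DIM) -> torch.Tensor:
       """Zero-pad the last (head) dimension up to `target`, if needed."""
       pad = target - x.shape[-1]
       return F.pad(x, (0, pad)) if pad > 0 else x

   def prefill_with_padding(run_prefill, q, k, v, head_dim: int) -> torch.Tensor:
       """Pad q/k/v, run the ragged prefill, and truncate the output.

       `run_prefill(q, k, v, sm_scale)` stands in for the FlashInfer wrapper call;
       planning (begin_forward) is assumed to have used the padded head_dim.
       """
       sm_scale = 1.0 / math.sqrt(head_dim)  # scale from the *original* head_dim
       q_p, k_p, v_p = (pad_head_dim(t) for t in (q, k, v))
       out = run_prefill(q_p, k_p, v_p, sm_scale)
       return out[..., :head_dim]  # drop the padded columns
   ```

   Zero-padding K leaves the attention scores unchanged (the extra dot-product terms are zero), and zero-padded V columns only contribute zero output columns that are truncated afterwards, so the padded computation matches the unpadded one up to floating-point error.
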
3. Verification and Results:
   The effectiveness of these changes, particularly the padding workaround for gte-small (or a similar model with head_dim=32), was verified by comparing the FlashInfer backend's output (final embedding logits, e.g., shape (10000, 768)) against Triton and a native PyTorch attention implementation (torch_native).

   Numerical Similarity (vs torch_native, for a gte-small-like model):
   - `torch.allclose(rtol=0.01, atol=0.001)`:
     - FlashInfer: True
     - Triton: True
   - `torch.allclose(rtol=0.001, atol=0.0001)`:
     - FlashInfer: False
     - Triton: False
   - Mean Absolute Error (MAE):
     - FlashInfer: 1.89077000e-05
     - Triton: 1.78243699e-05
   - Maximum Absolute Error:
     - FlashInfer: 9.76562500e-04
     - Triton: 9.76562500e-04

   These results show that the padded FlashInfer backend achieves an MAE on the order of ~1.8e-5 relative to the native PyTorch version, similar to Triton. The slightly larger maximum error and the failure at the tighter allclose tolerances are common for optimized kernels, especially with float16/bfloat16 dtypes, and are considered within acceptable limits.
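
   For reference, the comparison is of this flavor (a sketch; tensor names are illustrative):

   ```python
   import torch

   def compare_outputs(out: torch.Tensor, ref: torch.Tensor) -> dict:
       """Compare two (num_requests, hidden_size) embedding matrices."""
       out, ref = out.float(), ref.float()
       diff = (out - ref).abs()
       return {
           "allclose(rtol=1e-2, atol=1e-3)": torch.allclose(out, ref, rtol=1e-2, atol=1e-3),
           "allclose(rtol=1e-3, atol=1e-4)": torch.allclose(out, ref, rtol=1e-3, atol=1e-4),
           "mae": diff.mean().item(),
           "max_abs_err": diff.max().item(),
       }

   # e.g. compare_outputs(flashinfer_logits, torch_native_logits)
   ```
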

   Performance (seconds per 10,000 requests, for a gte-small-like model):
   - FlashInfer (padded): 39.551 seconds
   - Triton: 39.144 seconds
   - Torch Native: 46.192 seconds

   The padded FlashInfer backend demonstrates performance comparable to Triton and a significant improvement over the native PyTorch implementation.
I'm open to discussing whether the current solution is appropriate. It might be better to remove the temporary workaround and retain only the causal check, especially if full FlashInfer support is expected soon.
That said, I'm happy to keep the workaround in place while we wait for FlashInfer support to land. Thank you for taking the time to review this; I'm open to any suggestions.
Checklist
- [X] Format your code according to the Code Formatting with Pre-Commit.
- [X] Add unit tests as outlined in the Running Unit Tests.
- [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [X] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.