
Enable FlashInfer support for encoder models and add head_dim padding workaround

Open ccs96307 opened this issue 7 months ago • 0 comments

Motivation

This PR aims to enhance the FlashInfer attention backend in SGLang to address two primary goals:

  1. Enable support for encoder-only models: Currently, the FlashInfer backend needs adjustments to correctly handle non-causal attention required by encoder architectures.
  2. Resolve an "Invalid configuration" error for specific head dimensions: When using encoder models with certain head dimensions (e.g., head_dim=32 as found in Supabase/gte-small and potentially other BGE-like models) with FlashInfer's ragged prefill operations, an internal error is triggered, preventing these models from running.

The original issue is: #6050

Modifications

This PR introduces the following key changes:

1. Encoder Model Support (Non-Causal Attention):

  • In FlashInferAttnBackend.forward_extend, the causal flag is now dynamically determined. For layers with layer.attn_type == AttentionType.ENCODER_ONLY, causal is set to False to enable bidirectional (non-causal) attention.
  • For encoder self-attention, save_kv_cache is also set to False, since the KV cache is typically not reused for encoder self-attention the way it is in decoders (see the sketch below).
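As a rough illustration of the flag handling described above, here is a minimal, self-contained sketch. The AttentionType enum and the layer object are stubbed locally so the snippet runs on its own; in SGLang they come from the attention layer definitions, and the actual logic lives inside FlashInferAttnBackend.forward_extend, so this is illustrative rather than the exact diff.

```python
# Minimal, self-contained sketch (not the exact PR code).
from enum import Enum, auto

class AttentionType(Enum):      # stand-in for SGLang's AttentionType
    DECODER = auto()
    ENCODER_ONLY = auto()

class Layer:                    # stand-in for an attention layer carrying attn_type
    def __init__(self, attn_type):
        self.attn_type = attn_type

def resolve_extend_flags(layer, save_kv_cache=True):
    """Decide causal / save_kv_cache the way forward_extend would (simplified)."""
    encoder_only = layer.attn_type == AttentionType.ENCODER_ONLY
    causal = not encoder_only                      # bidirectional attention for encoder-only layers
    save_kv = save_kv_cache and not encoder_only   # skip KV-cache writes for encoder self-attention
    return causal, save_kv

print(resolve_extend_flags(Layer(AttentionType.ENCODER_ONLY)))  # (False, False)
print(resolve_extend_flags(Layer(AttentionType.DECODER)))       # (True, True)
```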

2. Workaround for FlashInfer head_dim Limitation (e.g., for head_dim=32): FlashInfer currently fails when using BatchPrefillWithRaggedKVCacheWrapper with head_dim < 64 (e.g., 32). To work around this, we pad the head dimension up to 64 during prefill and forward steps:

  • A global variable global_fake_head_dim (default: 64) controls the padded size.
  • During prefill:
    • If the model’s head_dim is less than global_fake_head_dim, we use the padded fake_head_dim for planning (begin_forward), but keep sm_scale based on the original head_dim for correctness.
  • During forward:
    • Q, K, and V tensors are padded along the head dimension.
    • sm_scale remains based on the original head_dim.
    • FlashInfer returns output with the padded size, which we truncate back to the original shape.

This workaround is temporary until native support for head_dim < 64 lands in FlashInfer. A minimal sketch of the padding path is shown below.
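The sketch below assumes PyTorch tensors whose last dimension is head_dim. The global_fake_head_dim name mirrors the variable mentioned above, while pad_head_dim and the run_prefill callback are hypothetical helpers standing in for FlashInfer's BatchPrefillWithRaggedKVCacheWrapper call, whose exact signature is omitted here.

```python
import torch
import torch.nn.functional as F

global_fake_head_dim = 64  # padded head size used when the real head_dim is smaller

def pad_head_dim(x: torch.Tensor, target: int = global_fake_head_dim) -> torch.Tensor:
    """Zero-pad the last (head) dimension of q/k/v up to `target`."""
    pad = target - x.shape[-1]
    return F.pad(x, (0, pad)) if pad > 0 else x

def padded_ragged_attention(q, k, v, run_prefill):
    """Run a ragged prefill with a padded head_dim, then truncate the output.

    `run_prefill(q, k, v, sm_scale)` stands in for the FlashInfer ragged
    prefill wrapper's run call; the real call in SGLang differs.
    """
    orig_head_dim = q.shape[-1]
    sm_scale = orig_head_dim ** -0.5              # scale stays based on the real head_dim
    if orig_head_dim < global_fake_head_dim:
        q, k, v = (pad_head_dim(t) for t in (q, k, v))
    out = run_prefill(q, k, v, sm_scale)          # kernel sees the padded tensors
    return out[..., :orig_head_dim]               # truncate back to the original shape
```

The planning step (begin_forward) would likewise be given the padded head dimension whenever the model's head_dim is below global_fake_head_dim, so that the kernel configuration and the tensors it receives agree.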

3. Verification and Results: The effectiveness of these changes, particularly the padding workaround for gte-small (or a similar model with head_dim=32), was verified by comparing the FlashInfer backend's output (final embedding logits, e.g., shape (10000, 768)) against Triton and a native PyTorch attention implementation (torch_native).

Numerical Similarity (vs torch_native for gte-small like model):

  • torch.allclose (rtol=0.01, atol=0.001):
    • FlashInfer: True
    • Triton: True
  • torch.allclose (rtol=0.001, atol=0.0001):
    • FlashInfer: False
    • Triton: False
  • Mean Absolute Error (MAE):
    • FlashInfer: 1.89077000e-05
    • Triton: 1.78243699e-05
  • Maximum Absolute Error:
    • FlashInfer: 9.76562500e-04
    • Triton: 9.76562500e-04

These results show that the padded FlashInfer backend achieves an MAE on the order of ~1.8e-5 relative to the native PyTorch version, similar to Triton. The slightly larger maximum error and the failures at tighter allclose tolerances are common for optimized kernels, especially with float16/bfloat16 dtypes, and are considered within acceptable limits.
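For reference, a comparison along these lines could be computed as follows; the tensor names are hypothetical and the actual verification script is not part of this description.

```python
import torch

def compare_outputs(out_flashinfer: torch.Tensor, out_reference: torch.Tensor) -> dict:
    """Report allclose checks and absolute-error statistics between two backends."""
    a, b = out_flashinfer.float(), out_reference.float()
    abs_err = (a - b).abs()
    return {
        "allclose(rtol=1e-2, atol=1e-3)": torch.allclose(a, b, rtol=1e-2, atol=1e-3),
        "allclose(rtol=1e-3, atol=1e-4)": torch.allclose(a, b, rtol=1e-3, atol=1e-4),
        "mean_abs_err": abs_err.mean().item(),
        "max_abs_err": abs_err.max().item(),
    }
```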

Performance (seconds / 10,000 requests, for gte-small like model):

  • FlashInfer (padded): 39.551 seconds
  • Triton: 39.144 seconds
  • Torch Native: 46.192 seconds

The padded FlashInfer backend demonstrates performance comparable to Triton and a significant improvement over the native PyTorch implementation.


I'm open to discussing whether the current solution is appropriate. It might be better to remove the temporary workaround and retain only the causal check, especially if full FlashInfer support is expected soon.

That said, I'm also happy to keep the workaround in place while we wait for native FlashInfer support to land. Thank you for taking the time to review this -- I'm open to any suggestions.

Checklist

  • [X] Format your code according to the Code Formatting with Pre-Commit.
  • [X] Add unit tests as outlined in the Running Unit Tests.
  • [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [X] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.
