[Roadmap] FlashInfer v0.2 to v0.3
Milestones
Our tentative roadmap includes the following milestones:
- [ ] SageAttention-2 in FlashAttention-3: Implement SageAttention-2 in the FlashAttention-3 template (#869 for fp8)
- [ ] Flex-Attention Compatible Interface: Standardize the JIT interface @shadowpa0327
- [ ] SM89 Kernel Optimization: Leverage Ada FP8 Tensor Cores for better performance on the RTX 6000 Ada & 4090.
- [ ] Template Refactoring: Refactor the FA-2 and MLA templates using CuTe.
- [x] MLA Acceleration: Optimize Multi-head Latent Attention (MLA) with Tensor Core support, follow-up of #551. Tracking issue: #792
- [ ] Triton Porting: Migrate elementwise, normalization, and other kernels (that are not on the critical path) to Triton.
- [ ] API Standardization: Simplify and standardize the attention APIs for better usability.
- [ ] POD-Attention Integration: Implement POD-Attention for improved chunked-prefill efficiency. (#858)
- [x] Nanoflow Parallelism: Expose Python-level APIs for performing GEMM and attention on a subset of SMs, which is required for Nanoflow-style parallelism, see #591.
- [ ] Fused Tree Speculative Sampling: Follow-up of #259; we should support tree speculative sampling as well. We will port the fused tree-speculative-sampling implementation written by @spectrometerHBH from https://github.com/mlc-ai/mlc-llm to accelerate EAGLE, Medusa, etc.
- [x] Improvements to Existing Top-P/K Sampling Operators: Change the algorithm to guarantee that all samples succeed within 32 rounds. #912
- [x] PyPI wheels: upload wheels to PyPI (pending issue: https://github.com/pypi/support/issues/5355)
- [ ] RoPE positions in batch attention interface: #701
- [ ] FP32 output for bf16 input: #696
- [ ] Per Head Scale Quantized KV-Cache: #721
We welcome your feedback and suggestions!
Let us know what features you'd like to see in FlashInfer.
Initial Blackwell support: https://github.com/flashinfer-ai/flashinfer/pull/747 (compute capability 10.0: Blackwell B100/B200; compute capability 12.0: Blackwell RTX 50 / Super). Flex attention support would be great to see as well.
Looking forward to Pod-Attention support!
To add more context, we have the following piece of code in the Mnemosyne codebase:
```python
def _arrange_sequences_for_execution(
    self,
    seq_schedule_metadata_list: List[SequenceScheduleMetadata],
) -> List[SequenceScheduleMetadata]:
    """
    We need to arrange sequences in a way that allows us to perform attention
    computation efficiently. Because attention kernels handle mixed batches
    poorly, we first split the sequences into prefill and decode:
        | prefill seqs | decode seqs |
    Secondly, when we mix sequences of different lengths, the attention kernel's
    parallelization heuristics fail and result in high latency. Thus, we need to
    further split the sequences:
        | long seqs | short seqs |
    Furthermore, within each group we can have kvp sequences. Some of these kvp
    sequences might not require the kv cache to be saved. So, within each group,
    we need to further organize sequences as follows:
        | non kvp seqs | kvp seqs w/ save_kv_cache | kvp seqs w/o save_kv_cache |
    """
```
In essence, we create 4 different instances of the FlashInfer prefill attention wrapper and call the kernel 4 times 😢 cc @yzh119
Could POD-Attention potentially support removing the prefill/decode batch scheduling logic and instead just run all the decode and prefill requests together?
@Edenzzzz Good idea, there is no reason to keep two sets of APIs. Actually, the current prefill attention can be used for decoding; just set the query length per request to 1.
We should use a unified BatchAttention API for all cases.
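To illustrate the point about reusing the prefill path for decode, here is a minimal sketch that drives `BatchPrefillWithPagedKVCacheWrapper` with one query token per request, so the "prefill" kernel effectively performs batched decode. It assumes the v0.2 `plan`/`run` interface; exact argument names may differ slightly across FlashInfer versions, and the head counts, page-table contents, and shapes below are purely illustrative.

```python
# Hedged sketch: batched decode through the batch *prefill* wrapper by setting
# every request's query length to 1 (qo_indptr becomes a simple arange).
import torch
import flashinfer

num_requests = 4
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, kv_layout="NHD")

# One query token per request -> decode semantics inside the prefill kernel.
qo_indptr = torch.arange(num_requests + 1, dtype=torch.int32, device="cuda")

# Illustrative paged KV-cache metadata: 2 pages per request, last page holds 5 tokens.
kv_page_indptr = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int32, device="cuda")
kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((num_requests,), 5, dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr, kv_page_indptr, kv_page_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size, causal=True,
)

q = torch.randn(num_requests, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda")
# Combined KV cache in NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
kv_cache = torch.randn(8, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
out = wrapper.run(q, kv_cache)  # [num_requests, num_qo_heads, head_dim]
```

A unified BatchAttention API would let a single plan/run pair cover mixed prefill and decode batches instead of maintaining separate wrappers per group.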
@yzh119 Thanks! I plan to try employing similar logic in SGLang this week.
I am a newbie to this repo but an experienced MLE. Is there anything I can contribute?
Please take a look at our updated roadmap: #1770
Closing, since this roadmap is superseded by #1770