[Feature] MiniMax-01 Lightning Attention
MiniMax-01 scales linear attention to a large-scale model (456B parameters), and FlashInfer should support it.
The prefill (forward) computation of lightning attention can be summarized as:
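(A rough sketch of the tiled recurrence, reconstructed from the description below; decay factors, if any, are omitted. Split the sequence into tiles $Q_t, K_t, V_t \in \mathbb{R}^{B \times d}$, then:)

$$
\begin{aligned}
O_t &= O_t^{\text{intra}} + O_t^{\text{inter}}, \\
O_t^{\text{intra}} &= \big[(Q_t K_t^\top) \odot M\big]\, V_t, \qquad M_{ij} = \mathbb{1}[i \ge j], \\
O_t^{\text{inter}} &= Q_t \,\mathrm{KV}_{t-1}, \\
\mathrm{KV}_t &= \mathrm{KV}_{t-1} + K_t^\top V_t, \qquad \mathrm{KV}_0 = 0.
\end{aligned}
$$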
The computation of O_intra for each tile is completely independent, so we can reuse our existing attention kernel by setting use_softmax=False in our attention variant class.
The computation of O_inter is essentially a scan operation. We can either perform the entire loop for a request within a single CTA, or use split-K. In the split-K case, we split N into chunks: first compute the KV matrix of each chunk, then compute the cumulative sum of KV, and finally compute O_inter for all tiles independently. The split-K chunk size can be selected adaptively to balance the O_inter overhead (determined by the number of chunks) against the O_intra overhead (determined by the chunk size). KV should be kept in fp32, considering the accumulation precision required for long contexts.
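For reference (not FlashInfer kernel code), here is a minimal PyTorch sketch of the split-K prefill path described above, with function and variable names of my own choosing: per-chunk KV matrices are computed first, an exclusive cumulative sum gives each chunk its prefix state, and O_intra / O_inter are then computed independently per chunk, with the KV accumulation kept in fp32.

```python
import torch

def lightning_prefill_splitk(q, k, v, chunk_size=256):
    """Split-K chunked prefill for (no-decay) linear attention, for reference.

    q, k: [N, d]; v: [N, d_v]. N is assumed to be a multiple of chunk_size
    to keep the sketch short.
    """
    n, d = q.shape
    d_v = v.shape[-1]
    c = chunk_size
    nc = n // c

    qc = q.view(nc, c, d)
    kc = k.view(nc, c, d)
    vc = v.view(nc, c, d_v)

    # Per-chunk KV = K_t^T V_t, accumulated in fp32 for long-context precision.
    kv_chunk = torch.einsum("tcd,tce->tde", kc.float(), vc.float())  # [nc, d, d_v]
    # Exclusive prefix sum: chunk t sees the cumulative KV of chunks 0..t-1.
    kv_prefix = torch.cumsum(kv_chunk, dim=0) - kv_chunk

    # O_intra: causally masked (Q K^T) V inside each chunk, no softmax.
    mask = torch.tril(torch.ones(c, c, dtype=torch.bool, device=q.device))
    scores = torch.einsum("tcd,tkd->tck", qc, kc).masked_fill(~mask, 0.0)
    o_intra = torch.einsum("tck,tke->tce", scores, vc)

    # O_inter: each chunk attends to the accumulated state of earlier chunks.
    o_inter = torch.einsum("tcd,tde->tce", qc.float(), kv_prefix).to(q.dtype)

    return (o_intra + o_inter).view(n, d_v)
```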
For decode, there is no need to maintain a KV-cache in the page table; we just need to keep one d × d KV matrix per request and accumulate KV with K_i^T V_i at step i. It is still possible to maintain a unified page layout for both the softmax attention layers' KV-cache and the linear attention layers' KV state; in that case, we can add gather-GEMM operators to FlashInfer for the O_inter computation.
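A correspondingly minimal sketch of the decode step, keeping a single fp32 per-request state matrix (again illustrative, not an existing FlashInfer API):

```python
import torch

def lightning_decode_step(q_i, k_i, v_i, kv_state):
    """One decode step of (no-decay) linear attention for a single request.

    q_i, k_i: [d]; v_i: [d_v]; kv_state: [d, d_v] in fp32, carried across steps.
    Returns (o_i, updated kv_state).
    """
    # Fold the current step's k_i^T v_i into the state first so the token
    # attends to itself, then read out o_i = q_i @ KV.
    kv_state = kv_state + torch.outer(k_i.float(), v_i.float())
    o_i = (q_i.float() @ kv_state).to(q_i.dtype)
    return o_i, kv_state
```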
Hi @yzh119, I am interested in working on this issue if no one is currently working on it. Could you assign this issue to me?
Hi @leifeng666, sure! I think a good starting point is to benchmark a state-of-the-art Triton implementation (such as flash-linear-attention) and see how far it is from speed-of-light.
@yzh119 At first glance, it seems lightning attention hasn't been supported in fla yet. I just created an issue asking the fla community whether they plan to implement it in the near future: https://github.com/fla-org/flash-linear-attention/issues/164. In the meantime, I will also try to find another Triton implementation and run a benchmark.
Check out https://github.com/fla-org/flash-linear-attention/commit/f14178c233725a2484540bdc413fb1086de279cc
Lightning Attention has been integrated into fla.