
[Feature] MiniMax-01 Lightning Attention

yzh119 opened this issue 10 months ago • 4 comments

MiniMax-01 scales linear attention to a large-scale model (456B parameters), and FlashInfer should support it.

The prefill computation of the lightning attention (forward) can be summarized as:

For the $i$-th tile (with $Q_i, K_i, V_i$ the tile's query/key/value blocks and $M$ the causal mask within the tile; decay factors omitted for brevity):

$$
O_i = O_i^{\mathrm{intra}} + O_i^{\mathrm{inter}}, \qquad
O_i^{\mathrm{intra}} = \left[ \left( Q_i K_i^\top \right) \odot M \right] V_i, \qquad
O_i^{\mathrm{inter}} = Q_i \, \mathrm{KV}_{i-1}
$$

where the $d \times d$ state is accumulated as $\mathrm{KV}_i = \mathrm{KV}_{i-1} + K_i^\top V_i$.

The computation of O_intra for each tile is completely independent, so we can reuse our existing attention kernel by setting use_softmax=False in our attention variant class.
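
As a rough numerical reference (not the FlashInfer kernel itself; function name, tile size, and single-head layout are illustrative), the per-tile O_intra term is just masked QK^T V without softmax:

```python
# Minimal PyTorch reference for the per-tile O_intra term (no softmax):
# O_intra[i] = [(Q_i K_i^T) ⊙ M] V_i, with a causal mask M inside each tile.
import torch

def o_intra_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      tile_size: int) -> torch.Tensor:
    """q, k, v: [seq_len, head_dim]; returns only the intra-tile part of the output."""
    seq_len, _ = q.shape
    o = torch.zeros(seq_len, v.shape[-1], dtype=torch.float32, device=q.device)
    causal = torch.tril(torch.ones(tile_size, tile_size, dtype=torch.bool, device=q.device))
    for start in range(0, seq_len, tile_size):
        end = min(start + tile_size, seq_len)
        qi, ki, vi = q[start:end].float(), k[start:end].float(), v[start:end].float()
        scores = qi @ ki.T                                            # [tile, tile], no softmax
        scores = scores.masked_fill(~causal[: end - start, : end - start], 0.0)
        o[start:end] = scores @ vi                                    # intra-tile contribution only
    return o
```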

The computation of O_inter is basically a scan operation: we can either perform the entire loop per request within a CTA, or use split-K. In the second case, we split N into chunks: we first compute the KV matrix of each chunk, take the cumulative sum of the KV matrices, and then compute the O_inter of all tiles independently (see the sketch below). The split-K chunk size can be selected adaptively to balance the O_inter overhead (determined by the number of chunks) against the O_intra computation overhead (determined by the chunk size). KV should be kept in fp32, considering the accumulation precision needed for long context.
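
A minimal sketch of the split-K path, assuming for simplicity that chunks coincide with tiles (names are illustrative; a real kernel would fuse these loops and batch over requests/heads):

```python
# Split-K sketch for O_inter: per-chunk KV states, an exclusive cumulative sum
# over chunks, then each chunk's O_inter computed independently.
import torch

def o_inter_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      chunk_size: int) -> torch.Tensor:
    """q, k, v: [seq_len, head_dim]; returns only the inter-chunk part of the output."""
    seq_len, head_dim = q.shape
    n_chunks = (seq_len + chunk_size - 1) // chunk_size
    # Step 1: per-chunk KV = K_c^T V_c, accumulated in fp32 for long-context accuracy.
    kv_chunks = torch.zeros(n_chunks, head_dim, v.shape[-1], dtype=torch.float32, device=q.device)
    for c in range(n_chunks):
        s, e = c * chunk_size, min((c + 1) * chunk_size, seq_len)
        kv_chunks[c] = k[s:e].float().T @ v[s:e].float()
    # Step 2: exclusive cumsum over chunks gives each chunk the KV state of its prefix.
    kv_prefix = torch.cumsum(kv_chunks, dim=0) - kv_chunks
    # Step 3: each chunk's O_inter depends only on its prefix state -> fully parallel.
    o = torch.zeros(seq_len, v.shape[-1], dtype=torch.float32, device=q.device)
    for c in range(n_chunks):
        s, e = c * chunk_size, min((c + 1) * chunk_size, seq_len)
        o[s:e] = q[s:e].float() @ kv_prefix[c]
    return o
```

With chunk_size equal to the tile size, o_intra_reference + o_inter_reference recovers the full (unnormalized, decay-free) causal linear attention output, which is handy for correctness checks.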

For decode, there is no need to maintain a KV-Cache in the page table; we just need to keep one d×d KV matrix per request and accumulate KV by K_i^T V_i at step i. It's still possible to maintain a unified page for the softmax attention layers' KV-Cache and the linear attention layers' KV state; in that case, we can add gather GEMM operators to FlashInfer for the O_inter computation.
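
The decode-time recurrence, in a minimal single-head sketch (decay factors again omitted; names are illustrative):

```python
# Decode sketch: each request keeps a single d×d KV state instead of a paged
# KV-cache, updated in place at every step.
import torch

def lightning_decode_step(q_t: torch.Tensor, k_t: torch.Tensor, v_t: torch.Tensor,
                          kv_state: torch.Tensor) -> torch.Tensor:
    """q_t, k_t, v_t: [head_dim]; kv_state: [head_dim, head_dim] in fp32, updated in place."""
    kv_state += torch.outer(k_t.float(), v_t.float())   # KV <- KV + k_t^T v_t
    return q_t.float() @ kv_state                        # o_t = q_t KV
```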

yzh119 avatar Jan 16 '25 14:01 yzh119

Hi @yzh119, I am interested in working on this issue if no one is currently working on it. Could you assign this issue to me?

leifeng666 avatar Feb 01 '25 19:02 leifeng666

Hi @leifeng666, sure! I think a good starting point is to run a benchmark on a state-of-the-art Triton implementation (such as flash-linear-attention) and see how far it is from speed-of-light.
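
A generic harness for that comparison could look like the sketch below. It takes the kernel under test as a callable (so no particular fla API is assumed), and the "speed-of-light" estimate here is just the memory-bound lower bound of streaming Q/K/V/O through HBM once; the peak-bandwidth number is a placeholder you should replace with your device's spec.

```python
# Timing harness comparing a linear-attention kernel against a rough
# memory-bound speed-of-light estimate (time to stream Q, K, V, O through HBM once).
import torch

def benchmark(fn, q, k, v, peak_bw_gb_s: float = 3350.0, iters: int = 100):
    """fn(q, k, v) -> o; peak_bw_gb_s is the device's peak HBM bandwidth (e.g. ~3350 GB/s on H100 SXM)."""
    for _ in range(10):                                   # warmup
        o = fn(q, k, v)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        o = fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    bytes_moved = sum(t.numel() * t.element_size() for t in (q, k, v, o))
    sol_ms = bytes_moved / (peak_bw_gb_s * 1e9) * 1e3
    print(f"{ms:.3f} ms/iter, memory-bound SoL ~{sol_ms:.3f} ms ({sol_ms / ms * 100:.1f}% of SoL)")
```

For long prefill the kernel may be compute-bound rather than bandwidth-bound, so a FLOP-based roofline is the other ceiling worth checking.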

yzh119 avatar Feb 02 '25 02:02 yzh119

@yzh119 At first glance, it seems lightning attention hasn't been supported in fla yet. I just created an issue asking the fla community whether they plan to implement it in the near future: https://github.com/fla-org/flash-linear-attention/issues/164. In the meantime, I will also try to find another Triton implementation and run a benchmark.

leifeng666 avatar Feb 05 '25 08:02 leifeng666

Check out https://github.com/fla-org/flash-linear-attention/commit/f14178c233725a2484540bdc413fb1086de279cc: Lightning Attention has been integrated into fla.

yzhangcs avatar Feb 05 '25 15:02 yzhangcs