
[RFC] Increase computation intensity for certain kernels

Open · sustcsonglin opened this issue 8 months ago · 1 comment

Proposal

The current chunk mode typically loads a 64x64 block, performs the computation, and then writes the resulting hidden state back, which incurs I/O overhead. In Tri Dao's Mamba2 implementation and xLSTM's chunkwise implementation, several 64x64 blocks are processed before a hidden state is saved, i.e., only once every 128 or 256 tokens, which reduces the I/O cost of storing hidden states and increases the arithmetic intensity. In my previous preliminary experiments, this yielded a non-trivial speedup. We would like to switch some matmul-heavy kernels to this strategy, e.g., the chunk kernels of simple_gla and deltanet. A rough sketch of the idea is given below.
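For illustration, here is a minimal, hedged sketch in plain PyTorch (not Triton, and not the actual fla API) of an ungated chunkwise linear-attention pass where the hidden state is written out only every `blocks_per_state` 64-token blocks instead of after every block. The function name, argument names, and shapes are illustrative assumptions; in a fused kernel the running state would stay in registers/shared memory between sub-blocks, which is where the I/O saving comes from.

```python
import torch

def chunk_linear_attn(q, k, v, chunk_size=64, blocks_per_state=2):
    """Sketch: process `blocks_per_state` chunks per materialized hidden state.

    q, k, v: [T, d], with T divisible by chunk_size * blocks_per_state.
    Returns per-token outputs and the hidden states that would be written to
    HBM -- one every chunk_size * blocks_per_state tokens rather than one per
    chunk_size tokens.
    """
    T, d = q.shape
    S = torch.zeros(d, d, dtype=q.dtype, device=q.device)  # running state K^T V
    outputs, saved_states = [], []
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device))
    for i in range(T // chunk_size):
        sl = slice(i * chunk_size, (i + 1) * chunk_size)
        qi, ki, vi = q[sl], k[sl], v[sl]
        # intra-chunk (causal) part + inter-chunk part via the running state S
        attn = (qi @ ki.t()).masked_fill(~mask, 0)
        outputs.append(attn @ vi + qi @ S)
        # update the running state; in a fused kernel this stays on-chip
        S = S + ki.t() @ vi
        # write S back only every `blocks_per_state` chunks -> fewer HBM stores
        if (i + 1) % blocks_per_state == 0:
            saved_states.append(S.clone())
    return torch.cat(outputs), torch.stack(saved_states)
```

Presumably the backward pass would recompute the intermediate per-chunk states within each group from the checkpointed state, trading extra matmuls for reduced HBM traffic, which is the "increase computation intensity" trade-off named in the title.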

Rationale

sustcsonglin · Feb 17 '25 06:02