
[Feature Request] Cached inference for native sparse attention

Open mutiann opened this issue 10 months ago • 1 comments

Feature Request

It seems that the current version of native sparse attention lacks support for cached inference. When use_cache=True is passed to model.generate, a shape mismatch error occurs, since the current parallel_nsa function simply does not handle the case where q and k, v have different lengths.
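For illustration only (this is a hypothetical sketch, not the actual parallel_nsa code): during prefill, q, k, and v share one sequence length, but during cached decoding only the new token's query is passed while k and v hold the whole prefix, so a kernel that assumes a single shared length T mis-sizes its score matrix:

```python
# Hypothetical sketch of the shape mismatch during cached decoding.
# Not the actual FLA kernel; it only illustrates the length assumption.

def attention_scores_shape(Tq: int, Tkv: int, assume_equal_lengths: bool):
    """Return the (rows, cols) shape of the score matrix a kernel builds."""
    if assume_equal_lengths:
        # A kernel written for prefill uses one length T for both axes.
        T = Tkv
        return (T, T)
    # A cache-aware kernel sizes the score matrix as (Tq, Tkv).
    return (Tq, Tkv)

# Prefill: q and the cache have the same length, so both paths agree.
assert attention_scores_shape(8, 8, assume_equal_lengths=True) == (8, 8)

# Cached decoding: q carries only the new token (Tq=1) while the cache
# holds 9 positions; the prefill-only kernel expects a (9, 9) score
# matrix, but the query side only provides 1 row.
assert attention_scores_shape(1, 9, assume_equal_lengths=True) == (9, 9)
assert attention_scores_shape(1, 9, assume_equal_lengths=False) == (1, 9)
```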

Motivation

It would be great if cached inference could be supported. This would make the NSA implementation in FLA usable and valuable for actual generation, and facilitate the use and evaluation of NSA.

Your Contribution

I'm glad to help with testing.

mutiann avatar May 30 '25 22:05 mutiann

I've actually been using the NSA kernel recently and am hence working on fixing this... I can try to get this done in a few days.

BTW @Espere-1119-Song, from what I understand it is not enough to simply change T to Tq to fix this, as a difference between Tq and Tkv will also break block selection, the causal mask, etc.
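To illustrate the causal-mask part of this (a minimal sketch under the assumption that the Tq queries correspond to the last Tq positions of a Tkv-long sequence, not the actual FLA kernel logic): when Tq != Tkv, query row i must be shifted by Tkv - Tq before the usual lower-triangular condition applies.

```python
# Hypothetical sketch of a causal mask that handles Tq != Tkv.
# Assumes the Tq queries are the trailing positions of a Tkv-long sequence.

def causal_mask(Tq: int, Tkv: int):
    """allowed[i][j] is True iff query i may attend to key position j."""
    offset = Tkv - Tq  # absolute position of query row i is offset + i
    return [[j <= offset + i for j in range(Tkv)] for i in range(Tq)]

# Prefill (Tq == Tkv): the familiar lower-triangular mask.
assert causal_mask(3, 3) == [
    [True, False, False],
    [True, True, False],
    [True, True, True],
]

# Cached decoding (Tq == 1): the new query attends to every cached key.
assert causal_mask(1, 4) == [[True, True, True, True]]
```

A kernel that keeps the plain `j <= i` condition after renaming T to Tq would wrongly mask out all but the first cached key in the decoding case, which is why the offset (and analogous adjustments in block selection) is needed.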

mutiann avatar Aug 13 '25 14:08 mutiann