[Feature Request] Cached inference for native sparse attention
Feature Request
It seems that the current version of native sparse attention lacks support for cached inference. When use_cache=True is passed to model.generate, a shape mismatch error occurs, because the current parallel_nsa function simply does not handle the case where q and k/v have different lengths.
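To illustrate the failure mode (shapes and names below are illustrative, not the actual FLA API): during generation with a KV cache, the prefill step sees equal query/key lengths, but every decode step afterwards has a single query against a growing cache, so any kernel assuming one shared length T breaks.

```python
# Illustrative shapes during model.generate with use_cache=True.
# B, H, D and the step lengths are hypothetical example values.
B, H, D = 1, 4, 64

# Prefill step: q, k, v all share the prompt length.
Tq = Tkv = 16
assert Tq == Tkv  # a kernel assuming a single T works here

# First decode step: one new query token, cache has grown by one.
Tq, Tkv = 1, 17
assert Tq != Tkv  # a kernel assuming a single T mis-shapes here
print(Tq, Tkv)
```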
Motivation
It would be great if cached inference could be supported. This would make the NSA implementation in FLA usable and valuable for actual generation, and facilitate the use and evaluation of NSA.
Your Contribution
I'm glad to help with testing.
I've actually been using the NSA kernel recently and am working on fixing this... I can try to get this done within a few days.
BTW @Espere-1119-Song, from what I understand it is not enough to simply change T to Tq: a difference between Tq and Tkv will also break block selection, the causal mask, etc.
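For example, the causal mask must offset query positions by the number of cached tokens; a minimal sketch of the idea (not the FLA implementation, which operates block-wise inside the Triton kernel):

```python
# Sketch: with a KV cache, query i's absolute position is
# (Tkv - Tq) + i, so the causal mask cannot be built from Tq alone.
import numpy as np

def causal_mask_with_cache(Tq: int, Tkv: int) -> np.ndarray:
    """Boolean mask of shape (Tq, Tkv); True means the key is visible."""
    offset = Tkv - Tq                        # number of cached past tokens
    q_pos = np.arange(Tq)[:, None] + offset  # absolute query positions
    k_pos = np.arange(Tkv)[None, :]          # absolute key positions
    return k_pos <= q_pos

# Single-step decoding: the one new query attends to all Tkv keys.
print(causal_mask_with_cache(1, 8).sum())   # -> 8
# Prefill (Tq == Tkv) reduces to the usual lower-triangular mask.
print(causal_mask_with_cache(4, 4).sum())   # -> 10
```

The same offset would have to flow into block selection, so that selected block indices are interpreted in the coordinate system of the full cached sequence rather than the current query chunk.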