Linear Attention
Hi,
Thank you for your great work! It's really helpful in my research.
I'm interested in using NATTEN with linear attention, which reorders the computation from (q @ k) @ v to q @ (k @ v). This reordering may further reduce the complexity.
In the current version of NATTEN, the attention map produced by na2d_qk() has shape [B, H, W, K^2] (ignoring heads). In linear attention, however, the multiplication between k and v is performed first: k[B, H, W, C, K^2] @ v[B, H, W, K^2, C] -> attn[B, H, W, C, C]. The output is then computed as q[B, H, W, C] @ attn[B, H, W, C, C] -> output[B, H, W, C].
I would like to request that this linear attention feature be added to NATTEN, such as na2d_kv() and na2d_qa(). May I know if this is possible? Thanks a lot!
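For illustration, here is a minimal numpy sketch of the two multiplication orders. It assumes each query's K^2 neighborhood of keys/values has already been gathered into an explicit axis (NATTEN does this gathering implicitly; the pre-gathered k_nb/v_nb tensors here are an assumption for the sake of the sketch). Without a softmax in between, the two orders agree by associativity:

```python
import numpy as np

# Hypothetical shapes: batch, height, width, neighborhood size K^2, channels.
B, H, W, K2, C = 2, 4, 4, 9, 8
rng = np.random.default_rng(0)

q = rng.standard_normal((B, H, W, C))
# Assumed pre-gathered neighborhoods of keys and values per query position.
k_nb = rng.standard_normal((B, H, W, K2, C))
v_nb = rng.standard_normal((B, H, W, K2, C))

# Standard order: (q @ k) @ v
scores = np.einsum('bhwc,bhwkc->bhwk', q, k_nb)         # [B, H, W, K^2]
out_qk_v = np.einsum('bhwk,bhwkc->bhwc', scores, v_nb)  # [B, H, W, C]

# Linear-attention order: q @ (k @ v)
kv = np.einsum('bhwkc,bhwkd->bhwcd', k_nb, v_nb)        # [B, H, W, C, C]
out_q_kv = np.einsum('bhwc,bhwcd->bhwd', q, kv)         # [B, H, W, C]

# With no softmax between the two products, the results match.
print(np.allclose(out_qk_v, out_q_kv))  # True
```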
Thank you for your interest.
Could you please clarify your definition of linear NA? I'm not sure I understand it. NA is already linear.
If you're referring to transposed attention, that's still quadratic, just w.r.t. channels/dim: instead of attending along one mode/axis, you're switching to another. I'm not sure NA even makes sense in that context.
Closing due to inactivity. Feel free to reopen if you still have questions.