flash-linear-attention
[Linear Attention] Update fused_recurrent.py for inference with normalization=true
The current linear attention can save a $KV$ state cache. This works when normalization is not enabled. When normalization is enabled, the output should be $\frac{QK^\top V}{QK^\top \mathbf{1}}$, and the denominator $QK^\top \mathbf{1}$, i.e. `Q @ sum(K)`, is missing the earlier keys.
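A minimal PyTorch sketch of the fix at the normalization step (the actual helper in fla/ops/linear_attn/utils.py may differ in details; the epsilon and shapes here are illustrative only):

```python
import torch

def normalize_output(q: torch.Tensor, k: torch.Tensor, o: torch.Tensor,
                     cum_k: torch.Tensor = None) -> torch.Tensor:
    # q, k, o: [B, H, S, D]; o is the unnormalized output Q @ (K^T V)
    k = k.cumsum(-2)                    # running sum of keys within the current input
    if cum_k is not None:
        # cum_k: [B, H, 1, D], sum of all previously cached keys, so the
        # denominator Q @ sum(K) also accounts for the earlier positions
        k = k + cum_k
    z = (q * k).sum(-1, keepdim=True)   # per-position Q @ sum(K), i.e. the QK1 term
    return o / (z + 1e-6)               # epsilon chosen arbitrarily for this sketch
```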
The last pull request only modified one file, and I'm not sure why that happened. I have re-opened it here and hope this version contains both changes.
Summary by CodeRabbit
- New Features
- Enhanced attention processing by adding an optional cumulative tensor input. This update refines the output normalization logic, offering increased flexibility and precision in the computation without altering the overall control flow.
[!IMPORTANT]
Review skipped
Draft detected.
Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command. You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.
Walkthrough
The changes introduce an additional optional parameter in three functions to support cumulative tensor handling. The function fused_recurrent_linear_attn now accepts a tensor parameter cum_k (defaulting to None), and its call to normalize_output is updated accordingly. The normalize_output function also accepts an optional parameter cum_k and conditionally adds its value to k if provided. Similarly, the fused_chunk_linear_attn function is updated to include cum_k, affecting its normalization process. The changes adjust the function signatures and update internal computations without altering the overall control flow.
Changes
| File(s) | Change Summary |
|---|---|
| fla/ops/linear_attn/fused_recurrent.py | Modified fused_recurrent_linear_attn to add an optional cum_k parameter (torch.Tensor, default None) and updated its call to normalize_output. |
| fla/ops/linear_attn/utils.py | Updated normalize_output to include an optional cum_k parameter (default None) and conditionally add this value to k during normalization. |
| fla/ops/linear_attn/fused_chunk.py | Modified fused_chunk_linear_attn to add an optional cum_k parameter (torch.Tensor, default None) and updated its call to normalize_output. |
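A call-level sketch of how the updated entry point might be used during cached decoding, assuming the `cum_k` keyword added in this PR. The other keyword names (`normalize`, `initial_state`, `output_final_state`) and the [B, H, T, D] layout are assumptions based on the walkthrough above and should be checked against the actual signatures:

```python
import torch
from fla.ops.linear_attn import fused_recurrent_linear_attn

B, H, D = 2, 4, 64
device = 'cuda'

# one new token during decoding
q = torch.randn(B, H, 1, D, device=device)
k = torch.randn(B, H, 1, D, device=device)
v = torch.randn(B, H, 1, D, device=device)

kv_state = torch.randn(B, H, D, D, device=device)  # cached KV ("memory") state
cum_k = torch.randn(B, H, 1, D, device=device)      # cached sum of earlier keys

# with normalization enabled, cum_k keeps the QK1 denominator aware of earlier keys
o, kv_state = fused_recurrent_linear_attn(
    q, k, v,
    initial_state=kv_state,
    output_final_state=True,
    normalize=True,
    cum_k=cum_k,
)
cum_k = cum_k + k.sum(-2, keepdim=True)  # update the cached key sum for the next step
```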
Sequence Diagram(s)
sequenceDiagram
participant FR as fused_recurrent_linear_attn
participant NO as normalize_output
FR->>NO: Call normalize_output(q * scale, k, o, cum_k)
alt cum_k provided
NO->>NO: Compute k = k + cum_k
else No cum_k
NO->>NO: Proceed without modifying k
end
NO->>FR: Return normalized output
participant FC as fused_chunk_linear_attn
FC->>NO: Call normalize_output(q * scale, k, o, cum_k)
alt cum_k provided
NO->>NO: Compute k = k + cum_k
else No cum_k
NO->>NO: Proceed without modifying k
end
NO->>FC: Return normalized output
Poem
I'm a rabbit with a hop and a bound,
Celebrating changes that are newly found.
A tiny tensor tip adds a little spark,
In our code garden, lighting up the dark.
With cum_k on board, our functions sing—
A joyful leap forward, oh what a spring!
🐰✨
Thanks for contributing! @yiyousong, can you add tests to your contribution? This will improve the robustness of the code.
@yzhangcs could you please give some comments? tests: https://github.com/fla-org/flash-linear-attention/blob/main/tests/ops/test_linear_attn.py layers: https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/linear_attn.py
Sorry, I don't think I understand how the tests work.
You could give it a try :)
pip install pytest
export COMPILER_MODE=1 # to speed up
pytest tests/ops/test_linear_attn.py
pytest tests/layers/test_linearatten_layer.py
You can see that it tests the functions automatically. What you need to do is test your cum_k and check whether fla/layers/linear_attn.py also needs to change, because I see you only changed the kernel.
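A rough pytest sketch of what such a test could look like: split a sequence in two, carry the cached state plus the proposed `cum_k` across the split, and compare against running the full sequence at once. Function and keyword names follow this PR and the walkthrough above, and the [B, H, T, D] layout is assumed; adjust to the repository's actual API:

```python
import pytest
import torch
from fla.ops.linear_attn import fused_recurrent_linear_attn


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_normalized_decoding_matches_full_sequence():
    B, H, T, D = 2, 4, 32, 64
    q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))

    # Reference: the whole sequence in one call, normalization enabled.
    ref, _ = fused_recurrent_linear_attn(q, k, v, normalize=True)

    # Split at position c and carry both the KV state and the key sum across.
    c = 16
    o1, state = fused_recurrent_linear_attn(
        q[:, :, :c], k[:, :, :c], v[:, :, :c],
        normalize=True, output_final_state=True)
    cum_k = k[:, :, :c].sum(-2, keepdim=True)  # [B, H, 1, D] cached key sum
    o2, _ = fused_recurrent_linear_attn(
        q[:, :, c:], k[:, :, c:], v[:, :, c:],
        normalize=True, initial_state=state, cum_k=cum_k)

    out = torch.cat([o1, o2], dim=2)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
```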
@yiyousong Hello, could you please explain a bit more about what this argument means and what the purpose of adding it is?
Linear attention without normalization equals $\phi(Q)\phi(K)^\top V$, i.e. $q_i \sum_j (k_j v_j)$ for each position $i$. Since $\sum_{j=1}^N (k_j v_j) = \sum_{j=c+1}^N (k_j v_j) + \sum_{j=1}^c (k_j v_j) = \sum_{j=c+1}^N (k_j v_j) + \text{cache}$, this is what your code already does when `initial_state != None`: the input $K$ then contains only the tokens from position $c+1$ to $N$ (the tokens after the cached positions).
Linear attention with normalization equals $\frac{\phi(Q)\phi(K)^\top V}{\phi(Q)\phi(K)^\top \mathbf{1}}$, i.e. for each position $i$ the output is $\frac{q_i \sum_j (k_j v_j)}{q_i \sum_j k_j}$.
Focus on the $\sum_j k_j$ term. During generation, the input to the function is often just the last token (shorter than the full sequence whenever `initial_state` is not None). In that case the function computes the sum of $K$ as just $k_N$, which is significantly smaller than the sum over all keys. To fix this, we need to cache the sum of the previous keys, $\sum_{i=1}^c k_i$ (shape [B, H, 1, D]). The full $\sum_j k_j$ is then recovered by adding this cached sum to the cumulative sum of the current inputs, $\sum_{i=c+1}^N k_i$ (shape [B, H, S, D]).
From the math, you can see that this issue only arises when the cache and normalization are used at the same time.
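A small self-contained check of this argument in plain PyTorch (no fla dependency; a toy causal linear attention, just to show that carrying both the KV cache and the cached key sum reproduces the full-sequence normalized output):

```python
import torch

torch.manual_seed(0)
B, H, T, D = 1, 1, 8, 4
q, k, v = (torch.rand(B, H, T, D) for _ in range(3))  # positive "features"

def causal_linear_attn(q, k, v, kv0=None, zk0=None):
    # kv0: cached sum of k_j v_j^T, [B, H, D, D]; zk0: cached sum of k_j, [B, H, 1, D]
    kv = torch.einsum('bhtd,bhte->bhtde', k, v).cumsum(2)  # running sum of outer products
    zk = k.cumsum(2)                                       # running sum of keys
    if kv0 is not None:
        kv = kv + kv0.unsqueeze(2)
        zk = zk + zk0
    num = torch.einsum('bhtd,bhtde->bhte', q, kv)  # q_i . sum(k_j v_j)
    den = (q * zk).sum(-1, keepdim=True)           # q_i . sum(k_j), the normalizer
    return num / den, kv[:, :, -1], zk[:, :, -1:]

ref, _, _ = causal_linear_attn(q, k, v)

c = 5
o1, kv_c, zk_c = causal_linear_attn(q[:, :, :c], k[:, :, :c], v[:, :, :c])
o2, _, _ = causal_linear_attn(q[:, :, c:], k[:, :, c:], v[:, :, c:], kv0=kv_c, zk0=zk_c)

print(torch.allclose(torch.cat([o1, o2], dim=2), ref, atol=1e-6))  # True
```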
However, the implementation was harder than I thought, as the compiled function does not accept an `if` statement on an optional tensor, so I cannot simply add a few parameters.
You may need to change all of the code that involves both normalization and the cache.
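For reference, the usual Triton workaround is to pass a compile-time flag (`tl.constexpr`) plus a dummy pointer, so the branch is resolved when the kernel is specialized rather than at run time. A minimal elementwise sketch of the pattern, not the repository's actual kernel (broadcasting a [B, H, 1, D] key sum over the sequence would need its own indexing):

```python
from typing import Optional
import torch
import triton
import triton.language as tl


@triton.jit
def _add_cached_k_kernel(k_ptr, cum_k_ptr, out_ptr, n_elements,
                         USE_CUM_K: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    k = tl.load(k_ptr + offs, mask=mask)
    if USE_CUM_K:  # compile-time branch: Triton builds one kernel per flag value
        k = k + tl.load(cum_k_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, k, mask=mask)


def add_cached_k(k: torch.Tensor, cum_k: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = torch.empty_like(k)
    n = k.numel()
    grid = (triton.cdiv(n, 1024),)
    # pass k itself as a harmless dummy pointer when cum_k is absent
    _add_cached_k_kernel[grid](k, cum_k if cum_k is not None else k,
                               out, n, USE_CUM_K=cum_k is not None, BLOCK=1024)
    return out
```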
@yiyousong Thank you, good point! We do need to support this. But I don't think cum_k is a good name; a better API design could be considered.
How about making normalization as a part of initial/final-state?
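One hypothetical shape for that suggestion (illustrative only, not the library's API): let the recurrent state carry both the KV matrix and the key-sum vector, so callers never handle a separate cum_k:

```python
from typing import NamedTuple, Optional
import torch


class LinearAttnState(NamedTuple):
    kv: torch.Tensor  # [B, H, D, D] running sum of k_j v_j^T
    z: torch.Tensor   # [B, H, 1, D] running sum of k_j (only used when normalizing)


def step(q, k, v, state: Optional[LinearAttnState] = None, normalize: bool = True):
    # single decoding step: q, k, v have shape [B, H, 1, D]
    kv = torch.einsum('bhld,bhle->bhde', k, v)
    z = k
    if state is not None:
        kv = kv + state.kv
        z = z + state.z
    o = torch.einsum('bhld,bhde->bhle', q, kv)
    if normalize:
        o = o / (q * z).sum(-1, keepdim=True)
    return o, LinearAttnState(kv, z)
```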
@yiyousong Hello, are you still working on this PR?
I was evaluating with my own code (I only used fla.ops, not fla.layers). Based on my experience, although the "z-state" (following the naming from the original linear attention paper) can be merged with the "memory state", I would not recommend it, as the z-state is only needed when normalization is True. Either way, changes to the lower-level API are unavoidable.
These changes are based on the code I modified to make it work for my model. I probably won't work on this further. Maybe after May 15th I can continue updating it, but currently I am too busy to debug code I won't use.