flash-linear-attention
Enhance FLA support for RWKV6
This pull request aims to enhance FLA support for RWKV6, improving both speed and accuracy in bf16. It also enables FLA on Intel cards.
FLA ChunkRWKV6 Optimized Implementation
This repository contains an optimized implementation of ChunkRWKV6 using FLA (Flash Linear Attention) techniques. Our goal is to improve both accuracy and speed over the standard CUDA implementation.
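For reference, here is a minimal usage sketch of the chunked RWKV6 kernel. The import path, argument order, and tensor layout are assumptions about the library at the time of this PR and may differ from the current API.

```python
# Minimal sketch, assuming `chunk_rwkv6` is exported from fla.ops.rwkv6 and
# takes (q, k, v, w, u) with q/k/v/w shaped (B, H, T, D) and u shaped (H, D).
import torch
from fla.ops.rwkv6 import chunk_rwkv6  # assumed import path

B, H, T, D = 8, 64, 4096, 64  # batch, heads, tokens, head size (C = H * D = 4096)
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
w = (-torch.rand_like(q)).requires_grad_()  # data-dependent decay in log space (kept negative)
u = torch.randn(H, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)

o, final_state = chunk_rwkv6(q, k, v, w, u, output_final_state=True)  # assumed kwargs
o.sum().backward()  # exercises the backward pass timed in the tables below
```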
Performance Comparison
We've conducted performance tests comparing our FLA BF16 implementation with the standard CUDA BF16 implementation (a sketch of the timing harness is given after the tables below). Here are some key results:
Test Case 1: B=32, T=4096, C=4096, HEAD_SIZE=64
| Implementation | Forward Time | Backward Time |
|---|---|---|
| CUDA BF16 | 32.80 ms | 148.05 ms |
| FLA BF16 | 50.17 ms | 162.42 ms |
Test Case 2: B=8, T=4096, C=4096, HEAD_SIZE=64
| Implementation | Forward Time | Backward Time |
|---|---|---|
| CUDA BF16 | 9.69 ms | 46.41 ms |
| FLA BF16 | 13.06 ms | 40.79 ms |
Where:
- B: Batch size
- T: Token length
- C: Hidden layer dimension
- HEAD_SIZE: Size of attention heads
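The numbers above were collected with a timing harness along the following lines. This is a simplified sketch: the warm-up and iteration counts, and the use of CUDA events, are assumptions and not the exact benchmark script attached later in this thread.

```python
import torch

def bench(fn, *args, warmup=3, iters=10):
    """Average forward/backward times (ms) of `fn` using CUDA events.
    Inputs in `args` are expected to have requires_grad=True so that the
    backward pass is exercised."""
    for _ in range(warmup):
        fn(*args).sum().backward()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fwd_ms = bwd_ms = 0.0
    for _ in range(iters):
        start.record()
        out = fn(*args)
        end.record()
        torch.cuda.synchronize()
        fwd_ms += start.elapsed_time(end)

        loss = out.sum()
        start.record()
        loss.backward()
        end.record()
        torch.cuda.synchronize()
        bwd_ms += start.elapsed_time(end)
    return fwd_ms / iters, bwd_ms / iters

# Example (hypothetical wrapper): fwd_ms, bwd_ms = bench(my_rwkv6_forward, q, k, v, w, u)
```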
Accuracy
We've measured the error ratios of various components against an fp32 CUDA reference (a sketch of the metric follows the numbers below). Our ChunkRWKV6 FLA implementation achieves error levels consistent with the CUDA implementation:
- y: 0.0020138283862787135
- gr: 0.00250389610197927
- gk: 0.002499128980485113
- gv: 0.0028262425242107
- gw: 0.0027358097395330894
- gu: 0.001821853127644057
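A relative-error metric along the following lines reproduces these values. The exact formula used in the comparison script attached later in this thread is not shown here, so treat this as an assumed, illustrative definition.

```python
import torch

def err_ratio(x: torch.Tensor, ref: torch.Tensor) -> float:
    """Relative error of `x` against an fp32 reference `ref`:
    RMS of the difference normalised by the RMS of the reference
    (assumed formula, commonly used in RWKV kernel checks)."""
    x, ref = x.float(), ref.float()
    err = (x - ref).square().mean().sqrt()
    base = ref.square().mean().sqrt()
    return (err / base).item()

# e.g. err_ratio(y_fla_bf16, y_cuda_fp32) for the output y, and likewise
# for the gradients gr, gk, gv, gw, gu.
```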
Please try to squash merge :)
@uniartisan Hello, many thanks for these great contributions! I will make some checks soon. However, could you restrict the revisions to the RWKV6 chunk only? You've defined many decorators for other purposes that are unrelated to this PR title. I think it would be better to create a separate PR for those changes. Additionally, please note that there are some formatting errors that are not aligned with PEP8 standards.
https://github.com/sustcsonglin/flash-linear-attention/blob/8dea8bdaa14eb1f2a06152691dcd238043811fe6/tests/ops/test_rwkv6.py
This file seems broken
Also, it is not recommended to trim the trailing spaces at the end of each line in the README file, as they are sometimes used as line breaks.
Your suggestion makes a lot of sense. Some of these changes were introduced by the editor. I'll first limit the changes to ChunkRWKV6 and fix the test.
checkrwkv6.tar.gz contains the code that compares CUDA with FLA.
Also, this pull request fixes https://github.com/sustcsonglin/flash-linear-attention/issues/29. The problem was introduced by bfloat16 when calculating dq and dk. By converting to float32 where necessary, using tf32 as much as possible, and changing the grouping order, the pull request speeds things up and achieves the same accuracy as the CUDA implementation (which is pure fp32 internally).
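The idea is illustrated below. This is only a PyTorch-level sketch of the casting pattern, not the actual Triton kernel changes: keep tensors in bf16 for storage, accumulate the precision-sensitive dq/dk products in fp32, and allow TF32 for matmuls where full fp32 is not needed.

```python
import torch

# Allow TF32 for matmuls that are not precision-critical (Ampere+ GPUs).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def accumulate_in_fp32(a_bf16: torch.Tensor, b_bf16: torch.Tensor) -> torch.Tensor:
    """Illustrates the pattern only: operands stay bf16 in memory, but the
    precision-critical product (e.g. the terms feeding dq/dk) is computed in
    fp32 and only cast back to bf16 at the end."""
    return (a_bf16.float() @ b_bf16.float()).to(torch.bfloat16)
```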
@uniartisan Hi, I just left some review comments, could you take a look?
Hi, I can't see any comments. Could you tell me where I should look?
@uniartisan Can you see the messages in your notification box?
Could you give me a review like this? https://github.com/sustcsonglin/flash-linear-attention/pull/44/files/4a3e2bb1d699c7e41ead7adc2f2403fb3e79ceb6
I can't see your messages :(
@uniartisan sure, sorry for my late reply
@uniartisan Can you see my updated comments between the lines?
Sorry, I don't know what's going on. I still cannot see your review comments. Maybe you can post them directly here. 😎
@yzhangcs Hello, I hope this finds you well. I have successfully synchronized all the latest changes to your project. Given your expertise and valuable insights, I was wondering if you could kindly take some time to review these updates at your earliest convenience. Your feedback is crucial to ensure we're on the right track, and I greatly appreciate your assistance in this matter. :)
@uniartisan Thank you for the update. I'm running your code locally as there is no CI with GPUs. Will sync with you soon.
@uniartisan Hi, could you grant me access to this branch so that I can push some updates?
Of course!!! Sorry for my late reply. I will try it :)
@uniartisan Hi, closing this PR as the new features are too tightly coupled. @sustcsonglin just pushed some new commits resolving the RWKV6 precision problems. Check those out for more details. You can create new PRs if something could still be improved.
Again, thank you for your contributions and hard work!