
Enhance FLA support for RWKV6

zhiyuan1i opened this pull request 1 year ago · 14 comments

This pull request aims to enhance FLA support for RWKV6, improving both speed and accuracy in bf16. It also enables FLA on Intel GPUs.

FLA ChunkRWKV6 Optimized Implementation

This repository contains an optimized implementation of ChunkRWKV6 using FLA (Flash Linear Attention) techniques. Our goal is to improve both accuracy and speed relative to the standard CUDA implementation.
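
For orientation, here is a minimal usage sketch of the chunked RWKV6 op. This is not code from the PR itself; the `fla.ops.rwkv6.chunk_rwkv6` import path, the `(B, H, T, HEAD_SIZE)` tensor layout, and the argument names are assumptions about the library's interface around this time, so check the repo for the exact API.

```python
import torch
from fla.ops.rwkv6 import chunk_rwkv6  # assumed import path -- verify in the repo

B, T, C, HEAD_SIZE = 8, 4096, 4096, 64
H = C // HEAD_SIZE  # number of heads

q = torch.randn(B, H, T, HEAD_SIZE, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
# per-channel decay, assumed here to be log-space (<= 0)
w = (-torch.rand(B, H, T, HEAD_SIZE, device='cuda', dtype=torch.bfloat16)).requires_grad_()
u = torch.randn(H, HEAD_SIZE, device='cuda', dtype=torch.bfloat16, requires_grad=True)  # "bonus" term

o, final_state = chunk_rwkv6(q, k, v, w, u, output_final_state=True)
o.sum().backward()  # populates the gradients compared in this PR (gr/gk/gv/gw/gu)
```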

Performance Comparison

We've conducted performance tests comparing our FLA BF16 implementation with the standard CUDA BF16 implementation (a sketch of the timing methodology follows the parameter list below). Here are some key results:

Test Case 1: B=32, T=4096, C=4096, HEAD_SIZE=64

| Implementation | Forward Time | Backward Time |
| -------------- | ------------ | ------------- |
| CUDA BF16      | 32.80 ms     | 148.05 ms     |
| FLA BF16       | 50.17 ms     | 162.42 ms     |

Test Case 2: B=8, T=4096, C=4096, HEAD_SIZE=64

| Implementation | Forward Time | Backward Time |
| -------------- | ------------ | ------------- |
| CUDA BF16      | 9.69 ms      | 46.41 ms      |
| FLA BF16       | 13.06 ms     | 40.79 ms      |

Where:

  • B: Batch size
  • T: Token length
  • C: Hidden layer dimension
  • HEAD_SIZE: Size of attention heads
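
The timing sketch mentioned above: a minimal CUDA-event harness in the spirit of these measurements. It is not the exact benchmark used for the tables; `fwd` is a hypothetical stand-in for whichever implementation (CUDA BF16 or FLA BF16) is under test.

```python
import torch

def bench_ms(fn, warmup=10, iters=50):
    for _ in range(warmup):       # warm up kernels / autotuner
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

# forward:  bench_ms(lambda: fwd(q, k, v, w, u))
# backward: time forward + backward together, then subtract the forward-only time
```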

Accuracy

We've measured the error ratios against the FP32 CUDA implementation for various components. Our chunk RWKV6 FLA implementation achieves error levels consistent with the CUDA implementation:

y:  0.0020138283862787135
gr: 0.00250389610197927
gk: 0.002499128980485113
gv: 0.0028262425242107
gw: 0.0027358097395330894
gu: 0.001821853127644057
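
For context, here is one common convention for computing such error ratios (a sketch; the actual comparison script is in the checkrwkv6.tar.gz attachment below): the mean absolute deviation from the fp32 CUDA reference, normalized by the reference's own mean magnitude.

```python
import torch

def err_ratio(out: torch.Tensor, ref: torch.Tensor) -> float:
    """Relative error of `out` (e.g. FLA bf16) against `ref` (fp32 CUDA)."""
    out, ref = out.float(), ref.float()
    return ((out - ref).abs().mean() / ref.abs().mean()).item()

# e.g. print('gw:', err_ratio(gw_fla, gw_cuda_fp32))
```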

zhiyuan1i · Aug 13 '24 08:08

Please try to squash merge :)

zhiyuan1i · Aug 13 '24 08:08

@uniartisan Hello, many thanks for these great contributions! I will make some checks soon. However, could you restrict the revisions to the RWKV6 chunk only? You've defined many decorators for other purposes that are unrelated to this PR title. I think it would be better to create a separate PR for those changes. Additionally, please note that there are some formatting errors that are not aligned with PEP8 standards.

yzhangcs · Aug 13 '24 18:08

https://github.com/sustcsonglin/flash-linear-attention/blob/8dea8bdaa14eb1f2a06152691dcd238043811fe6/tests/ops/test_rwkv6.py

This file seems broken

yzhangcs · Aug 13 '24 18:08

Also, it is not recommended to strip the trailing spaces at the end of each line in the README file, as they are sometimes used as line breaks.

yzhangcs · Aug 13 '24 19:08

Your suggestion makes a lot of sense. Some of these changes were introduced by my editor. I'll first limit the changes to the RWKV6 chunk and fix the test.

zhiyuan1i · Aug 14 '24 05:08

checkrwkv6.tar.gz: here is the code that compares CUDA with FLA.

zhiyuan1i · Aug 14 '24 05:08

Also, this pull request fixes https://github.com/sustcsonglin/flash-linear-attention/issues/29. The problem was introduced by bfloat16 when computing dq and dk. By converting to float32 where necessary, using TF32 as much as possible, and changing the grouping order, the pull request gains speed and achieves the same accuracy as the CUDA implementation (which is pure fp32 internally).
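
To illustrate the technique described here (a sketch of the idea only, not the PR's actual Triton kernel code): the bf16 operands feeding the numerically sensitive dq/dk products are upcast to fp32, and fp32 matmuls are allowed to use TF32.

```python
import torch

# TF32 keeps fp32 dynamic range with a shorter mantissa: much faster on
# Ampere+ GPUs, and far more accurate than accumulating in bf16.
torch.backends.cuda.matmul.allow_tf32 = True

def sensitive_matmul(a_bf16: torch.Tensor, b_bf16: torch.Tensor) -> torch.Tensor:
    # Upcasting before the product means the accumulation -- where bf16 lost
    # precision for dq/dk -- happens in fp32/TF32 instead.
    return a_bf16.float() @ b_bf16.float()
```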

zhiyuan1i · Aug 14 '24 07:08

@uniartisan Hi, I just left some review comments, could you have a look?

yzhangcs · Aug 14 '24 19:08

> @uniartisan Hi, I just left some review comments, could you have a look?

Hi, I can't see any comments. Could you tell me where to look?

zhiyuan1i · Aug 15 '24 05:08

@uniartisan Can you see the messages in your notification box?

yzhangcs · Aug 15 '24 09:08

Could you give me a review like this one? https://github.com/sustcsonglin/flash-linear-attention/pull/44/files/4a3e2bb1d699c7e41ead7adc2f2403fb3e79ceb6

I can't see your messages :(

zhiyuan1i · Aug 15 '24 11:08

@uniartisan sure, sorry for my late reply

yzhangcs · Aug 18 '24 19:08

@uniartisan Can you see my updated comments between the lines?

yzhangcs · Aug 18 '24 19:08

> @uniartisan Can you see my updated comments between the lines?

Sorry, I don't know what's going on. I still cannot see your review comments. Maybe you can post them directly here. 😎

zhiyuan1i · Aug 20 '24 11:08

@yzhangcs Hello, I hope this message finds you well. I have successfully synchronized all the latest changes with your project. Given your expertise and valuable insights, I was wondering if you could kindly take some time to review these updates at your earliest convenience. Your feedback is crucial to ensure we're on the right track, and I greatly appreciate your assistance in this matter. :)

zhiyuan1i · Aug 26 '24 06:08

@uniartisan Thank you for the update. I'm running your code locally, as there is no CI with GPUs. Will sync with you soon.

yzhangcs · Aug 26 '24 07:08

@uniartisan Hi, can you give me access to this branch so that I can make some updates?

yzhangcs · Aug 30 '24 09:08

> Hi, can you give me access to this branch so that I can make some updates?

Of course!!! Sorry for my late reply. I will try it :)

zhiyuan1i · Sep 01 '24 15:09

@uniartisan Hi, closing this PR as the new features are too tightly coupled. @sustcsonglin just pushed some new commits resolving the RWKV6 precision problems; check those out for more details. You can create new PRs if something could still be improved.

Again, thank you for your contributions and hard work!

yzhangcs · Sep 23 '24 18:09