
Current FLA RWKV6 implementation has significant precision issues in pure bf16 mode

Open · howard-hou opened this issue 1 year ago · 1 comment

The current FLA RWKV6 implementation has significant precision issues in pure bf16 mode. Below are the results from my experiments:

| Mode | y | gr | gk | gv | gw | gu |
| --- | --- | --- | --- | --- | --- | --- |
| CUDA bf16 (fp32 internal) | 0.0016603531206355376 | 0.0017877683404764239 | 0.0017853925508536652 | 0.0022316154634133964 | 0.0018482808625786967 | 0.0018472627187992381 |
| FLA fp32 | 5.153093822969028e-07 | 5.860136550906496e-07 | 5.969486336398631e-07 | 5.833091583780125e-07 | 2.3036314788307143e-05 | 3.5015232226862115e-07 |
| FLA bf16 | 0.0025760101921418134 | 0.0029575546041739134 | 0.002951189528185581 | 0.0031975613176225934 | 0.08319189127088046 (!!!) | 0.0017254238302962922 |

As shown, the FLA bf16 results exhibit significantly larger errors, particularly for gw, where the error is roughly 45x that of the CUDA bf16 kernel.

Please look into this precision issue.

Thank you

howard-hou · Jul 05 '24

@howard-hou Thanks for reporting this issue; we will look into it soon.

yzhangcs · Jul 08 '24

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] · Aug 25 '24

This issue was closed because it has been inactive for 7 days since being marked as stale.

github-actions[bot] · Sep 01 '24