flash-linear-attention
Current FLA RWKV6 implementation has significant precision issues in pure bf16 mode
The current FLA RWKV6 implementation has significant precision issues in pure bf16 mode. Below are the error measurements from my experiments:
| Metric | CUDA bf16 (fp32 internal) | FLA fp32 | FLA bf16 |
|---|---|---|---|
| y  | 0.0016603531206355376 | 5.153093822969028e-07  | 0.0025760101921418134 |
| gr | 0.0017877683404764239 | 5.860136550906496e-07  | 0.0029575546041739134 |
| gk | 0.0017853925508536652 | 5.969486336398631e-07  | 0.002951189528185581 |
| gv | 0.0022316154634133964 | 5.833091583780125e-07  | 0.0031975613176225934 |
| gw | 0.0018482808625786967 | 2.3036314788307143e-05 | 0.08319189127088046 (!!!) |
| gu | 0.0018472627187992381 | 3.5015232226862115e-07 | 0.0017254238302962922 |
As shown, the FLA bf16 results exhibit significantly larger errors, particularly for gw.
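
For context, a minimal sketch of this kind of bf16-vs-fp32 comparison is below. It is not the exact harness behind the table above: it assumes a relative L2 error metric (the report does not state which metric was used), compares a bf16 run directly against an fp32 run of the same kernel (the table presumably compares each configuration against a higher-precision reference), and treats the kernel callable (e.g. `fla.ops.rwkv6.chunk_rwkv6`) and its return convention as assumptions that may differ across FLA versions.

```python
import torch

def rel_err(x: torch.Tensor, ref: torch.Tensor) -> float:
    # Relative L2 error against the reference; the original report does not
    # state which metric was used, so this choice is an assumption.
    x, ref = x.float(), ref.float()
    return (torch.linalg.norm(x - ref) / torch.linalg.norm(ref)).item()

def compare_precision(kernel, r, k, v, w, u):
    # `kernel` is any callable implementing RWKV6 attention, e.g.
    # fla.ops.rwkv6.chunk_rwkv6 (hypothetical usage; check the actual
    # signature in your FLA version). r, k, v, w, u are fp32 tensors.
    def run(dtype):
        # Fresh leaf tensors in the requested dtype so gradients accumulate per run.
        ins = [t.detach().clone().to(dtype).requires_grad_(True) for t in (r, k, v, w, u)]
        out = kernel(*ins)
        y = out[0] if isinstance(out, tuple) else out  # some versions also return a state
        y.float().sum().backward()                     # scalar loss -> comparable gradients
        return y, [t.grad for t in ins]

    y_ref,  g_ref  = run(torch.float32)
    y_bf16, g_bf16 = run(torch.bfloat16)

    print("y: ", rel_err(y_bf16, y_ref))
    for name, g, g0 in zip(("gr", "gk", "gv", "gw", "gu"), g_bf16, g_ref):
        print(f"{name}:", rel_err(g, g0))
```

A higher-precision baseline (for example a naive float64 recurrence) could be substituted for the fp32 run as the reference, which is presumably closer to how the numbers above were obtained.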
Please look into this precision issue.
Thank you
@howard-hou Thanks for reporting this issue; we will look into it soon.
This issue is stale because it has been open for 30 days with no activity.
This issue was closed because it has been inactive for 7 days since being marked as stale.