ring-flash-attention
Numerical errors in backward
Were you able to find out the reason for the small numerical errors in backward pass with ring flash attention?
I found the errors increase as you increase the world size, so it does seem related to the fact that flash attention returns 16-bit tensors; even though we accumulate in a 32-bit buffer, that apparently isn't enough.
Maybe it would be an easy PR to flash attention to have it return raw fp32, or to do the accumulation upstream?
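To illustrate the effect being discussed, here is a minimal numpy sketch (not the actual ring-flash-attention code, and it ignores the LSE rescaling a real ring implementation does): each "rank" produces a partial output that is rounded to fp16 before it reaches the fp32 accumulation buffer, so the rounding error is baked in per chunk and grows with the number of ring steps.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4096  # per-rank output size (illustrative)

def ring_error(world_size):
    """Mean abs error when fp16 partials are summed in an fp32 buffer."""
    # One partial attention output per rank, as in ring attention.
    x = rng.standard_normal((world_size, n)).astype(np.float32)
    exact = x.sum(axis=0)                 # full-fp32 reference reduction
    acc = np.zeros(n, dtype=np.float32)   # 32-bit accumulation buffer
    for partial in x:
        # The kernel hands back fp16, so the rounding happens *before*
        # the fp32 accumulation can help.
        acc += partial.astype(np.float16).astype(np.float32)
    return float(np.abs(acc - exact).mean())

for ws in (2, 4, 8):
    print(f"world_size={ws}: mean abs error {ring_error(ws):.2e}")
```

Each chunk contributes an independent ~2^-11 relative rounding error, so the accumulated error grows roughly with the square root of the world size, which matches the observation above. Returning fp32 partials from the kernel (or accumulating inside it) would remove that per-chunk rounding entirely.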