ring-flash-attention

Numerical errors in backward

Open • grimulkan opened this issue 7 months ago • 2 comments

Were you able to find out the reason for the small numerical errors in the backward pass with ring flash attention?

I found that the errors grow as you increase the world size, so they do seem to be related to the fact that flash attention returns 16-bit tensors. Even though we accumulate in a 32-bit buffer, each partial result has already been rounded to 16 bits before it is added, so the 32-bit buffer does not seem to be enough.
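To make the suspected mechanism concrete, here is a toy NumPy sketch. It is not the ring-flash-attention code: it assumes the backward pass can be modelled as a plain sum of one partial gradient per rank, uses `float16` as a stand-in for whatever 16-bit dtype the kernel returns, and the function name `accumulation_error` is made up for illustration. It compares an fp32 buffer fed with 16-bit-rounded partials against an fp32 buffer fed with unrounded partials.

```python
import numpy as np


def accumulation_error(world_size, n=4096, seed=0):
    """Toy model: sum `world_size` partial gradients into an fp32 buffer.

    Returns (err_rounded, err_full): the mean absolute error against a
    float64 reference when each partial is first rounded to float16
    (as a 16-bit kernel output would be) versus kept at fp32 precision.
    """
    rng = np.random.default_rng(seed)
    # One hypothetical partial gradient per rank, generated in float64.
    partials = rng.standard_normal((world_size, n))
    reference = partials.sum(axis=0)

    buf_rounded = np.zeros(n, dtype=np.float32)
    buf_full = np.zeros(n, dtype=np.float32)
    for p in partials:
        # The partial is rounded to 16 bits before it reaches the buffer.
        buf_rounded += p.astype(np.float16).astype(np.float32)
        # Hypothetical alternative: the partial arrives in fp32.
        buf_full += p.astype(np.float32)

    err_rounded = float(np.abs(buf_rounded - reference).mean())
    err_full = float(np.abs(buf_full - reference).mean())
    return err_rounded, err_full


if __name__ == "__main__":
    for ws in (2, 8, 32):
        rounded, full = accumulation_error(ws)
        print(f"world_size={ws:3d}  16-bit partials: {rounded:.3e}  "
              f"fp32 partials: {full:.3e}")
```

In this toy model the rounding error is introduced once per partial, so the fp32 buffer cannot recover it, and the accumulated error grows with the number of partials. Whether this fully explains the errors seen in the real backward pass is still an open question.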

Maybe it is an easy PR in flash attention to have them return raw fp32, or do the accumulation upstream?

grimulkan • Jun 27 '24 23:06