flash-attention
Return 32-bit for external accumulation
Would it be possible, as an option, to return higher-precision tensors where relevant, so that users can break the attention computation into blocks?
For example, in ring + flash attention (https://github.com/zhuzilin/ring-flash-attention), the computation is split across several GPUs, each calling the flash-attention kernel, with the results exchanged and accumulated outside flash-attn. But since the returned tensors are all 16-bit, even accumulating into a 32-bit buffer does not seem to be as accurate as a single flash-attn call.
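For concreteness, here is a rough sketch in plain PyTorch of the kind of block merge that happens outside the kernel. This is not the actual flash-attn or ring-flash-attention code; the function name, tensor layouts, and shapes are my own illustrative assumptions. It just shows where the 16-bit round-trip loses accuracy:

```python
import torch

def merge_blocks(out1, lse1, out2, lse2):
    """Merge two partial attention results computed over disjoint KV chunks,
    using their log-sum-exp statistics (the standard online-softmax rescale).

    Assumed layouts (illustrative only, not flash-attn's actual ones):
      out*: (batch, seqlen_q, nheads, headdim), 16-bit as returned by the kernel
      lse*: (batch, seqlen_q, nheads), fp32
    """
    lse = torch.logaddexp(lse1, lse2)          # combined normalizer, kept in fp32
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # rescale factor for block 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # rescale factor for block 2
    # Accumulating in fp32 here cannot recover precision: out1/out2 were
    # already rounded to 16 bits when the kernel returned them.
    out = w1 * out1.float() + w2 * out2.float()
    return out, lse

# toy usage with random data
b, s, h, d = 1, 128, 8, 64
out1 = torch.randn(b, s, h, d, dtype=torch.float16)
out2 = torch.randn(b, s, h, d, dtype=torch.float16)
lse1 = torch.randn(b, s, h, dtype=torch.float32)
lse2 = torch.randn(b, s, h, dtype=torch.float32)
merged, merged_lse = merge_blocks(out1, lse1, out2, lse2)
```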
I'm guessing flash-attn internally accumulates at higher precision and downcasts the result before returning it? If so, would it be possible to return the raw full-precision tensors for such use cases, assuming they exist in the internal implementation? If everything really is 16-bit end to end, then maybe there is some other reason for the accuracy gap.
This would apply to both the forward and backward calls.
Related: https://github.com/zhuzilin/ring-flash-attention/issues/42