If q and k are bf16, does the q @ k^T output need to be fp32?
I have read the flash attention code, and in it the q @ k^T output dtype is fp32. If q and k are bf16, does the q @ k^T output need to be fp32? In standard attention, the q @ k^T output dtype is the same as the q and k input dtype.
We choose to have q @ K^T in fp32 for better numerical stability.
"We choose to have q @ K^T in fp32 for better numerical stability." Thank you for your reply. In a previous reply (in another issue), I saw that you define numerical range (stability) as: do we have enough range to avoid Inf and NaN? I still have a question: can we use bf16 as the output dtype? bf16 has the same numerical range as fp32 (only the precision differs), so in my opinion bf16 can also avoid Inf and NaN. @tridao
you can try that out
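One rough way to try it without a GPU is to emulate bf16 rounding in plain Python and compare a softmax over fp32 scores with a softmax over the same scores rounded to bf16. This is only a sketch, not the flash attention kernel: `to_bf16` and `softmax` are hypothetical helpers written for this illustration, the two score values are made up, and finite inputs are assumed.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper: bf16 is the top 16 bits of an fp32 value, so we round
    the fp32 bit pattern and zero the low 16 bits. Inf/NaN inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias, ties go to even
    bits &= 0xFFFF0000                   # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Range: bf16 keeps the fp32 exponent, so a large value stays finite.
print(math.isfinite(to_bf16(1e38)))  # True

# Precision: between 256 and 512, neighbouring bf16 values are 2 apart,
# so two distinct (made-up) q @ k^T scores collapse to the same value.
scores_fp32 = [300.3, 301.0]
scores_bf16 = [to_bf16(s) for s in scores_fp32]
print(scores_bf16)            # [300.0, 300.0]

print(softmax(scores_fp32))   # roughly [0.33, 0.67]
print(softmax(scores_bf16))   # [0.5, 0.5]
```

In this toy case bf16 does avoid Inf and NaN, as the question suggests, but the rounding error in the scores changes the attention weights, since softmax exponentiates differences between scores. How much that matters for a real model would still need to be measured.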