`fp32` or even `fp64` support for cuDNN FlashAttention
Currently, `jax.nn.dot_product_attention` only supports fp16 or bf16 when `implementation='cudnn'`; see
https://github.com/google/jax/blob/e3c4b20fa04893ad986c3184387fbd3817f1515d/jax/_src/cudnn/fused_attention_stablehlo.py#L240
I am wondering whether it would be possible to support fp32 or even fp64; that would be great for applications that require numerical accuracy.
CC @Cjkkkk
I would like to convey my strong support for FP32 flash attention. It would be extremely beneficial for my ongoing work!
@guyuntian: do you need FP32 or is TF32 enough?
@sbodenstein In my experiments, TF32 reduces final model quality compared to FP32. Depending on the specific task, TF32 with a larger hidden dimension can sometimes perform better, but not always. It would be best if FlashAttention supported both TF32 and FP32.
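As an aside, the TF32-vs-FP32 distinction for matmuls can already be controlled in JAX today via `jax.default_matmul_precision` (a sketch; on Ampere+ GPUs fp32 matmuls default to TF32, while on CPU both settings produce identical results):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((4, 4), dtype=jnp.float32)

# Force true fp32 accumulation (disables TF32 on GPUs that use it).
with jax.default_matmul_precision('float32'):
    out_fp32 = a @ a

# Explicitly allow the faster TF32 path.
with jax.default_matmul_precision('tensorfloat32'):
    out_tf32 = a @ a
```

The open question in this issue is getting the same choice inside the fused cuDNN attention kernel, where this flag does not apply.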
What other requirements would you have for FP32/FP64, such as head_dim, seq_len, etc.?
@mnicely In computational quantum physics, we typically set head_dim to 64. The seq_len varies depending on the specific problem, but in most cases we deal with relatively small seq_len values (≤ 256). Given these constraints, a FlashAttention implementation could even omit the online softmax.