
`fp32` or even `fp64` support for cudnn flashattention

Open · wangleiphy opened this issue 1 year ago

Currently, jax.nn.dot_product_attention only supports fp16 or bf16 when implementation='cudnn', see https://github.com/google/jax/blob/e3c4b20fa04893ad986c3184387fbd3817f1515d/jax/_src/cudnn/fused_attention_stablehlo.py#L240

I am wondering whether it is possible to support fp32 or even fp64; that would be great for applications which require numerical accuracy.

wangleiphy avatar Sep 11 '24 04:09 wangleiphy

CC @Cjkkkk

superbobry avatar Sep 11 '24 08:09 superbobry

I would like to convey my strong support for FP32 flash attention. It holds great significance and would be extremely beneficial for my ongoing work!

guyuntian avatar Feb 11 '25 03:02 guyuntian

@guyuntian: do you need FP32 or is TF32 enough?

sbodenstein avatar Feb 12 '25 12:02 sbodenstein

@sbodenstein In my experiments, TF32 reduces overall performance compared to FP32. Depending on the specific task, TF32 with a larger hidden dimension can sometimes perform better, but not always. It would be best if flash attention supported both TF32 and FP32.
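For what it's worth, JAX already exposes this choice for plain fp32 matmuls via the jax.default_matmul_precision context manager; a flash-attention kernel supporting both would mirror that knob. A sketch (on pre-Ampere GPUs and CPU the hint is a no-op):

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (256, 256), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (256, 256), dtype=jnp.float32)

# Force full-fp32 accumulation for accuracy-sensitive work
# (on Ampere+ GPUs, fp32 matmuls may otherwise run in TF32):
with jax.default_matmul_precision("float32"):
    exact = a @ b

# Explicitly allow TF32 (faster, ~10-bit mantissa) where acceptable:
with jax.default_matmul_precision("tensorfloat32"):
    fast = a @ b
```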

guyuntian avatar Feb 12 '25 13:02 guyuntian

What other requirements would you have for FP32/FP64, such as head_dim, seq_len, etc.?

mnicely avatar Feb 17 '25 15:02 mnicely

@mnicely In computational quantum physics, we typically set the head_dim to 64. The seq_len varies depending on the specific problem, but in most cases, we deal with relatively small seq_len values (≤ 256). Given these constraints, in FlashAttention implementations, we might even omit the online softmax.
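Under those constraints (head_dim 64, seq_len ≤ 256) the full seq_len × seq_len attention matrix fits comfortably in memory, so a non-flash fp32 reference without online softmax is straightforward. A minimal single-head sketch in jax.numpy (function name is hypothetical):

```python
import jax.numpy as jnp
from jax.nn import softmax

def plain_attention(q, k, v):
    """q, k, v: (seq_len, head_dim), fp32 throughout.

    Materializes the full score matrix and applies a plain softmax —
    no online/streaming softmax, which is fine for short sequences.
    """
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale          # (seq_len, seq_len), fits in memory
    weights = softmax(scores, axis=-1)  # full softmax in one shot
    return weights @ v                  # (seq_len, head_dim)
```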

guyuntian avatar Feb 18 '25 15:02 guyuntian