flash-attention-jax
Return the same dtype as the one passed in via qkv
As it stands, the attention functions return float32 regardless of the input dtype. This change makes the output dtype match the input instead, so you can use bf16, fp16, or whatever. I've tested and used the causal variant, but haven't tried the cosine similarity or non-causal versions. Caveat emptor.
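For anyone who wants to see the pattern in isolation, here is a minimal sketch of what "return the input dtype" looks like. It is a plain reference attention, not this library's kernel: the function name `causal_attention_ref`, the choice to accumulate in float32, and the choice to take the output dtype from `q` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def causal_attention_ref(q, k, v):
    """Hypothetical reference causal attention that preserves the input dtype.

    q, k, v: arrays of shape (..., seq_len, dim) sharing one dtype.
    """
    out_dtype = q.dtype  # assumption: q, k, v share a dtype, so q's is used

    # Do the math in float32 for numerical stability (an assumption of this sketch).
    q32, k32, v32 = (x.astype(jnp.float32) for x in (q, k, v))

    scale = q32.shape[-1] ** -0.5
    scores = jnp.einsum("...id,...jd->...ij", q32, k32) * scale

    # Causal mask: position i may only attend to positions j <= i.
    n_q, n_k = scores.shape[-2:]
    mask = jnp.tril(jnp.ones((n_q, n_k), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(jnp.float32).min)

    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("...ij,...jd->...id", weights, v32)

    # Cast back so the caller gets the dtype they passed in.
    return out.astype(out_dtype)


q = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
out = causal_attention_ref(q, q, q)
print(out.dtype)
```

The only part relevant to this PR is the final `astype(out_dtype)`; everything before it is just enough attention to make the example runnable.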