Return the same dtype as the qkv inputs

Open enolan opened this issue 1 year ago • 0 comments

As it stands, the attention functions return float32 regardless of the input dtype. This change makes the output dtype match the input dtype instead, so you can use bf16, fp16, or any other floating-point type. I've tested and used the causal variant, but haven't tried the cosine-similarity or non-causal versions. Caveat emptor.
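
To illustrate the intended behavior, here is a minimal sketch using plain (non-flash) softmax attention rather than the library's actual kernels. The function name `attention_matching_dtype` is hypothetical, and computing in float32 before casting back is an assumption made for the example, not necessarily what this change does internally.

```python
import jax
import jax.numpy as jnp


def attention_matching_dtype(q, k, v):
    """Plain softmax attention whose output dtype matches q's dtype.

    Hypothetical helper for illustration; not part of flash-attention-jax.
    """
    out_dtype = q.dtype
    # Assumption: do the math in float32 for numerical stability.
    q32, k32, v32 = (x.astype(jnp.float32) for x in (q, k, v))
    scale = q32.shape[-1] ** -0.5
    scores = jnp.einsum("...id,...jd->...ij", q32, k32) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("...ij,...jd->...id", weights, v32)
    # Cast back so the caller gets the dtype they passed in.
    return out.astype(out_dtype)


# Usage: bf16 inputs give a bf16 output instead of float32.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(kk, (4, 8)).astype(jnp.bfloat16) for kk in keys)
out = attention_matching_dtype(q, k, v)
```

With float32 inputs the casts are no-ops, so existing float32 callers would see no change in this sketch.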

enolan, Jul 05 '23 19:07