flash-attention-jax
Reshape error in causal_flash_attention when sequence length is not a multiple of 1024
First off, thanks for writing this. It's been a substantial improvement, even if hand-written CUDA kernels would've been better.
I've discovered a bug with certain sequence lengths. For example, with a sequence length of 1025 you get `TypeError: reshape total size must be unchanged, got new_sizes (1025, 256, 64) for shape (2, 1024, 256, 64)`,
with a traceback pointing to `causal_flash_attention.py:96`, which is this line: `out = out.reshape(q_len, bh, v_dim)`. AFAICT the problem occurs whenever the sequence length is greater than 1024 and not a multiple of 1024.
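The shapes in the error message suggest what's going on: the output is built in 1024-row chunks, so a length of 1025 produces two chunks (2 × 1024 = 2048 rows), which can't be reshaped back to 1025 rows. A minimal sketch of the failing reshape (the chunking/padding behaviour here is my assumption based on the error, not taken from the library code):

```python
import math
import jax.numpy as jnp

# Assumed chunking: ceil(1025 / 1024) = 2 chunks of 1024 rows each,
# matching the (2, 1024, 256, 64) shape reported in the error.
CHUNK_SIZE = 1024
q_len, bh, v_dim = 1025, 256, 64
n_chunks = math.ceil(q_len / CHUNK_SIZE)              # 2
out = jnp.zeros((n_chunks, CHUNK_SIZE, bh, v_dim))    # (2, 1024, 256, 64)
out = out.reshape(q_len, bh, v_dim)                   # TypeError: 2*1024*256*64 != 1025*256*64
```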
Repro:
```python
import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)
_ = causal_flash_attention(q, k, v)
```
That fails; changing 1025 to 1024 works fine.
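As a stopgap I've been padding the sequence axis up to the next multiple of 1024 and slicing the padding back off afterwards. For causal attention the real queries never attend to the appended key positions, so the sliced result should be unaffected. A rough sketch (assumptions: inputs are laid out as (batch, heads, seq, dim), the chunk size is 1024, and `causal_flash_attention` returns just the output tensor in the same layout; adjust the slice if it returns a tuple):

```python
import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

def padded_causal_flash_attention(q, k, v, chunk_size=1024):
    # Pad the sequence axis to a multiple of chunk_size, run the kernel,
    # then drop the padded rows again.
    seq_len = q.shape[2]
    pad = (-seq_len) % chunk_size
    if pad:
        widths = ((0, 0), (0, 0), (0, pad), (0, 0))
        q, k, v = (jnp.pad(x, widths) for x in (q, k, v))
    out = causal_flash_attention(q, k, v)
    return out[:, :, :seq_len]

q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)
out = padded_causal_flash_attention(q, k, v)
print(out.shape)  # (1, 1, 1025, 16)
```

It would be nicer if the kernel handled the padding internally, of course.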