
Reshape error in causal_flash_attention when sequence length is not a multiple of 1024

Open · enolan opened this issue on Dec 14, 2023 · 0 comments

First off, thanks for writing this. It's been a substantial improvement for me, even if the hand-written CUDA kernels would've been better.

I've discovered a bug with odd sequence lengths. For example, with a sequence length of 1025 you get TypeError: reshape total size must be unchanged, got new_sizes (1025, 256, 64) for shape (2, 1024, 256, 64). with a traceback pointing to causal_flash_attention.py:96, which is this line: out = out.reshape(q_len, bh, v_dim). AFAICT the problem occurs whenever the sequence length is greater than 1024 and not a multiple of 1024; judging by the (2, 1024, 256, 64) shape in the error, the queries get split into chunks of 1024, and that chunk dimension can't be folded back into q_len when q_len isn't divisible by 1024.

Repro:

import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)
_ = causal_flash_attention(q, k, v)

That fails; changing 1025 to 1024 works fine.
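
In case it helps anyone hitting the same thing, a workaround that seems to work for me is to pad the sequence axis up to the next multiple of 1024 and slice the padding off afterwards; with a causal mask the original query positions never attend to the trailing padded keys, so their outputs shouldn't change. Rough sketch below (padded_causal_flash_attention is just a wrapper name I made up, and this assumes causal_flash_attention returns a single output tensor shaped like q):

import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

def padded_causal_flash_attention(q, k, v, block=1024):
    # Hypothetical wrapper, not part of the library: pad the sequence axis up to
    # the next multiple of `block`, run the kernel, then drop the padding.
    seq_len = q.shape[-2]
    pad = -seq_len % block
    if pad == 0:
        return causal_flash_attention(q, k, v)
    pad_width = ((0, 0), (0, 0), (0, pad), (0, 0))  # pad only the sequence axis
    q_p, k_p, v_p = (jnp.pad(t, pad_width) for t in (q, k, v))
    out = causal_flash_attention(q_p, k_p, v_p)
    # The causal mask means positions < seq_len never see the padded keys,
    # so slicing the padding off recovers the unpadded result.
    return out[..., :seq_len, :]

q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)
out = padded_causal_flash_attention(q, k, v)  # shape (1, 1, 1025, 16)

Obviously it'd be nicer for the reshape in causal_flash_attention itself to handle non-multiple-of-1024 lengths, since the padding wastes compute on up to 1023 extra positions.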

enolan · Dec 14 '23 01:12