flash-attention-jax
more general mask support
The general case of attention is (using jaxtyping annotations):

```python
q: Float["lq d"]
k: Float["lkv d"]
v: Float["lkv o"]
mask: Bool["lq lkv"]
returns: Float["lq o"]
```
But it looks like this library currently only supports a one-dimensional mask?
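For reference, the general case above can be sketched as a naive dense implementation (assuming the standard scaled dot-product formulation; `masked_attention` is a hypothetical name, not part of this library):

```python
import jax
import jax.numpy as jnp

def masked_attention(q, k, v, mask):
    # q: [lq, d], k: [lkv, d], v: [lkv, o], mask: [lq, lkv] bool
    # Materializes the full [lq, lkv] score matrix -- exactly the
    # quadratic cost that flash attention is designed to avoid.
    scores = q @ k.T / jnp.sqrt(q.shape[-1])     # [lq, lkv]
    scores = jnp.where(mask, scores, -jnp.inf)   # disallowed pairs -> -inf
    weights = jax.nn.softmax(scores, axis=-1)    # [lq, lkv]; NaN if a row is fully masked
    return weights @ v                           # [lq, o]
```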
@GallagherCommaJack at the moment, it only supports key masking.
Generalized 2D masking would defeat the purpose of flash attention, since materializing an `[lq, lkv]` mask incurs the quadratic memory cost the algorithm is designed to avoid.
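To illustrate the difference: a key mask is one boolean per key, so it costs O(lkv) memory and can be applied per key/value block inside the tiled kernel. A dense sketch of the semantics (hypothetical function name, not this library's API):

```python
import jax
import jax.numpy as jnp

def key_masked_attention(q, k, v, key_mask):
    # key_mask: [lkv] bool -- one flag per key, O(lkv) memory.
    # Equivalent to a 2D mask where every query row is identical,
    # i.e. mask = jnp.broadcast_to(key_mask, (lq, lkv)).
    scores = q @ k.T / jnp.sqrt(q.shape[-1])          # [lq, lkv]
    scores = jnp.where(key_mask[None, :], scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v        # [lq, o]
```

An arbitrary 2D mask has no such per-key factorization, so it cannot be stored or streamed in sub-quadratic space.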