flash-attention-jax

more general mask support

Open: GallagherCommaJack opened this issue 2 years ago • 1 comment

the general case of attention is (using annotations from jaxtyping)

q: Float["lq d"]
k: Float["lkv d"]
v: Float["lkv o"]
mask: Bool["lq lkv"]

returns: Float["lq o"]

but it looks like right now this library only supports a 1-dimensional mask?
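For concreteness, here is a minimal dense reference for the general case described above, written in plain JAX. It is only a sketch and not this library's API: the function name `masked_attention`, the 1/sqrt(d) scaling, and the use of a large negative fill value (rather than `-inf`, so that a fully masked row does not produce NaN) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def masked_attention(q, k, v, mask):
    """Dense attention with an arbitrary boolean mask (hypothetical helper).

    q: (lq, d), k: (lkv, d), v: (lkv, o), mask: (lq, lkv) -> returns (lq, o)
    """
    scale = q.shape[-1] ** -0.5                # assumed 1/sqrt(d) scaling
    scores = (q @ k.T) * scale                 # (lq, lkv), materialized in full
    neg = jnp.finfo(scores.dtype).min          # finite stand-in for -inf
    scores = jnp.where(mask, scores, neg)      # True = attend, False = ignore
    weights = jax.nn.softmax(scores, axis=-1)  # normalize over keys
    return weights @ v                         # (lq, o)


# A 1-dimensional key mask is the special case where every query row
# shares the same mask.
lq, lkv, d, o = 4, 6, 8, 3
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (lq, d))
k = jax.random.normal(kk, (lkv, d))
v = jax.random.normal(kv, (lkv, o))
key_mask = jnp.array([True, True, False, True, False, True])
mask_2d = jnp.broadcast_to(key_mask[None, :], (lq, lkv))
out = masked_attention(q, k, v, mask_2d)
```

Note that this version builds the full `(lq, lkv)` score matrix, which is exactly what a flash-attention style kernel tries to avoid.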

GallagherCommaJack, Sep 20 '22 04:09

@GallagherCommaJack at the moment, it only supports key masking

generalized 2d masking would defeat the purpose of flash attention, as you incur the quadratic cost
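One way to read the point above: with a key mask, attention can be computed one block of keys at a time, and each block only needs the matching slice of an `lkv`-length mask, whereas an arbitrary 2D mask is itself an `lq × lkv` array that has to exist somewhere. The following is a rough sketch of blockwise attention with a running softmax and a key mask. It is not the implementation in this repository; the function name, the `chunk` parameter, and the Python loop are assumptions for illustration, and it assumes at least one key is unmasked.

```python
import jax
import jax.numpy as jnp


def chunked_key_masked_attention(q, k, v, key_mask, chunk=128):
    """Blockwise attention with a 1D key mask (hypothetical sketch).

    q: (lq, d), k: (lkv, d), v: (lkv, o), key_mask: (lkv,) -> returns (lq, o)
    Only (lq, chunk) scores exist at any one time.
    """
    lq, d = q.shape
    lkv, o = v.shape
    scale = d ** -0.5
    neg = jnp.finfo(q.dtype).min

    acc = jnp.zeros((lq, o), q.dtype)           # running weighted sum of values
    row_sum = jnp.zeros((lq, 1), q.dtype)       # running softmax denominator
    row_max = jnp.full((lq, 1), neg, q.dtype)   # running max for stability

    for start in range(0, lkv, chunk):
        kc = k[start:start + chunk]
        vc = v[start:start + chunk]
        mc = key_mask[start:start + chunk]      # O(chunk) slice of the mask

        s = (q @ kc.T) * scale                  # (lq, chunk) block of scores
        s = jnp.where(mc[None, :], s, neg)

        new_max = jnp.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = jnp.exp(row_max - new_max)  # rescale earlier blocks
        p = jnp.exp(s - new_max)
        p = jnp.where(mc[None, :], p, 0.0)       # masked keys contribute nothing

        acc = acc * correction + p @ vc
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        row_max = new_max

    return acc / row_sum
```

The mask storage here is linear in `lkv`. Supporting a fully general mask in the same loop would mean slicing an `(lq, lkv)` boolean array per block, so the mask alone brings back the quadratic memory footprint.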

lucidrains, Sep 21 '22 16:09