flash-attention-jax

support for per-head scales for cosine sim attention

Status: Open. GallagherCommaJack opened this issue 2 years ago • 6 comments

usually with cosine-sim models I train with learned per-head scales on the attention logits. I guess I could get the same effect by multiplying q and k by sqrt(scales) before the dot product, but that's probably less stable

GallagherCommaJack, Sep 23 '22 16:09
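The sqrt(scales) trick mentioned above can be checked numerically. The minimal JAX sketch below compares applying a learned per-head scale to the cosine-similarity logits with folding sqrt(scale) into both q and k before the dot product. The helper names, the `eps` guard, and parameterizing the scale as `exp(log_scale)` so that it stays positive are assumptions for illustration, not code from flash-attention-jax. The two forms agree up to float rounding whenever every scale is positive; this sketch does not settle whether the folded form is less stable in low precision.

```python
import jax
import jax.numpy as jnp


def l2norm(x, eps=1e-6):
    # Unit-normalize along the feature (last) axis; eps guards against zero vectors.
    return x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), eps)


def scaled_cosine_logits(q, k, scale):
    # q, k: (heads, seq, dim); scale: (heads,) learned per-head logit scale.
    # The scale is applied pointwise to the cosine-similarity logits.
    sim = jnp.einsum("hid,hjd->hij", l2norm(q), l2norm(k))
    return sim * scale[:, None, None]


def folded_cosine_logits(q, k, scale):
    # Same logits, with sqrt(scale) folded into both q and k before the dot product.
    # Requires scale > 0 (assumed here via an exp(log_scale) parameterization).
    s = jnp.sqrt(scale)[:, None, None]
    return jnp.einsum("hid,hjd->hij", l2norm(q) * s, l2norm(k) * s)


# Usage: compare both formulations on random inputs.
kq, kk, ks = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4, 8, 16))
k = jax.random.normal(kk, (4, 8, 16))
scale = jnp.exp(jax.random.normal(ks, (4,)))  # hypothetical learned scales, kept positive
direct = scaled_cosine_logits(q, k, scale)
folded = folded_cosine_logits(q, k, scale)
```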

@GallagherCommaJack try keeping it at a constant fixed scale of 10

it worked well for me, as well as for Boris (for Craiyon). in fact, it is his best run

lucidrains, Sep 24 '22 16:09
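For reference, here is a minimal, non-flash sketch of cosine-sim attention with the fixed logit scale of 10 suggested above. It is a plain JAX illustration, not the implementation in flash-attention-jax or flash-cosine-sim-attention; the `l2norm` helper and its `eps` value are assumptions. Because q and k are unit-normalized, every logit lies in [-scale, scale], i.e. [-10, 10] here.

```python
import jax
import jax.numpy as jnp


def l2norm(x, eps=1e-6):
    # Unit-normalize along the feature (last) axis (eps is an assumed guard value).
    return x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), eps)


def cosine_sim_attention(q, k, v, scale=10.0):
    # q, k, v: (heads, seq, dim). With unit-norm q and k each logit is a cosine
    # similarity in [-1, 1], so multiplying by a fixed scale bounds it to [-scale, scale].
    logits = jnp.einsum("hid,hjd->hij", l2norm(q), l2norm(k)) * scale
    attn = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("hij,hjd->hid", attn, v)


# Usage with random inputs.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8, 16))
k = jax.random.normal(kk, (2, 8, 16))
v = jax.random.normal(kv, (2, 8, 16))
out = cosine_sim_attention(q, k, v)
```

A side effect of normalizing q and k is that the output ignores their norms entirely, so the fixed scale is the only thing setting how sharp the softmax can get.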

I am trying to use this with models I've already spent a decent amount of compute training; it would be a lot more work to retrain them from scratch

GallagherCommaJack, Sep 24 '22 16:09

I could of course fine-tune with a constant scale, but that seems like a worse option than relying on XLA to fuse the scaling here, since the non-cosine-sim version should be drop-in compatible.

GallagherCommaJack, Sep 24 '22 16:09
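One way to read the drop-in compatibility point: if an unmodified attention kernel computes softmax(q k^T / sqrt(dim)) v, the cosine normalization and a per-head scale can be folded into its inputs beforehand, leaving the kernel itself untouched. In the sketch below, `standard_attention` is a hypothetical stand-in for such a kernel (it is not the flash-attention-jax API), and the other helper names are illustrative. Folding the whole scale into q alone avoids the sqrt, so it also works if a learned scale is not positive.

```python
import jax
import jax.numpy as jnp


def l2norm(x, eps=1e-6):
    # Unit-normalize along the feature (last) axis.
    return x / jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), eps)


def standard_attention(q, k, v):
    # Hypothetical stand-in for an unmodified (non-cosine-sim) attention kernel
    # computing softmax(q k^T / sqrt(dim)) v.
    dim = q.shape[-1]
    logits = jnp.einsum("hid,hjd->hij", q, k) * dim ** -0.5
    return jnp.einsum("hij,hjd->hid", jax.nn.softmax(logits, axis=-1), v)


def fold_cosine_sim_into_inputs(q, k, scale):
    # Normalize q and k, then fold scale * sqrt(dim) into q only, so the kernel's
    # 1/sqrt(dim) cancels and its logits become cos_sim * scale per head.
    dim = q.shape[-1]
    q = l2norm(q) * (scale[:, None, None] * dim ** 0.5)
    return q, l2norm(k)


@jax.jit
def cosine_sim_attention_via_standard(q, k, v, scale):
    q, k = fold_cosine_sim_into_inputs(q, k, scale)
    return standard_attention(q, k, v)


def cosine_sim_attention_reference(q, k, v, scale):
    # Direct per-head-scaled cosine-sim attention, for comparison.
    logits = jnp.einsum("hid,hjd->hij", l2norm(q), l2norm(k)) * scale[:, None, None]
    return jnp.einsum("hij,hjd->hid", jax.nn.softmax(logits, axis=-1), v)


# Usage: per-head scales as they might come out of a pretrained model (made-up values).
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4, 8, 16))
k = jax.random.normal(kk, (4, 8, 16))
v = jax.random.normal(kv, (4, 8, 16))
scale = jnp.array([5.0, 10.0, 15.0, 20.0])
out = cosine_sim_attention_via_standard(q, k, v, scale)
ref = cosine_sim_attention_reference(q, k, v, scale)
```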

@GallagherCommaJack ahh, i don't know if i can support that; i'm going all-in on a fixed scale of 10: https://wandb.ai/dalle-mini/dalle-mini/reports/Fix-Swin-v2--VmlldzoyNDA4Mzc3 (blue curve)

lucidrains, Sep 24 '22 17:09

hmm, really? in a normal implementation the scales are just a pointwise op on the logits, between the dot product and the softmax. why does flash attention make that harder?

GallagherCommaJack, Sep 24 '22 17:09
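To make the question above concrete: in a flash-attention-style loop that walks over key/value blocks with a running (online) softmax, a per-head scale is still a single pointwise multiply on each block of logits, applied before the running max and sum are updated. The sketch below is a simplified pure-JAX illustration of that idea, not the flash-attention-jax kernel; `blockwise_attention`, the `chunk` size, and the requirement that the key length be divisible by `chunk` are assumptions for illustration. For cosine-sim attention, q and k would be l2-normalized before the call.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, scale, chunk=4):
    # q: (heads, n, dim); k, v: (heads, m, dim); scale: (heads,) per-head logit scale.
    # Keys/values are consumed chunk by chunk with a running (online) softmax.
    # Assumes m is divisible by chunk.
    h, m, d = k.shape
    k_chunks = k.reshape(h, m // chunk, chunk, d).transpose(1, 0, 2, 3)
    v_chunks = v.reshape(h, m // chunk, chunk, v.shape[-1]).transpose(1, 0, 2, 3)
    s = scale[:, None, None]

    def step(carry, kv):
        acc, row_max, row_sum = carry
        kc, vc = kv
        # The per-head scale is one pointwise multiply on this block of logits.
        logits = jnp.einsum("hid,hjd->hij", q, kc) * s
        new_max = jnp.maximum(row_max, logits.max(axis=-1))
        # Rescale the previously accumulated sums to the new running max.
        correction = jnp.exp(row_max - new_max)
        p = jnp.exp(logits - new_max[..., None])
        acc = acc * correction[..., None] + jnp.einsum("hij,hjd->hid", p, vc)
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    n = q.shape[1]
    init = (
        jnp.zeros((h, n, v.shape[-1]), q.dtype),  # unnormalized output accumulator
        jnp.full((h, n), -jnp.inf, q.dtype),      # running row max
        jnp.zeros((h, n), q.dtype),               # running softmax denominator
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
    return acc / row_sum[..., None]


def reference_attention(q, k, v, scale):
    # Plain attention with the same per-head scale, for comparison.
    logits = jnp.einsum("hid,hjd->hij", q, k) * scale[:, None, None]
    return jnp.einsum("hij,hjd->hid", jax.nn.softmax(logits, axis=-1), v)


# Usage: 12 keys processed in 3 chunks of 4.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 6, 16))
k = jax.random.normal(kk, (2, 12, 16))
v = jax.random.normal(kv, (2, 12, 16))
scale = jnp.array([0.5, 2.0])
out = blockwise_attention(q, k, v, scale)
ref = reference_attention(q, k, v, scale)
```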

@GallagherCommaJack it isn't difficult, it is just unnecessary

you can always fork it and add it yourself, if you need it for your specific pretrained network

cosine sim attention isn't even faithful to the original flash attention. it is kind of a new direction i'm taking it https://github.com/lucidrains/flash-cosine-sim-attention

lucidrains, Sep 24 '22 18:09