Jack Gallagher
interesting that this works with `grad` outside of `scan` and `remat` - it should probably fail under `grad` alone, without either of those?
with the fix it's working with `lq == lkv` under `jax.checkpoint`! still fails with `lq != lkv`, which I'm trying to debug now
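for reference, here's roughly how I'm exercising the two cases - a minimal sketch assuming the top-level `flash_attention(q, k, v, key_mask)` entry point and `(batch, heads, len, dim)` layout from the README; the import path, signature, and return value may differ on the cosine-sim branch (if the kernel returns auxiliary outputs, sum just the attention output):

```python
import jax
import jax.numpy as jnp
from jax import random

# assumed entry point / signature from the README; adjust for this fork if needed
from flash_attention_jax import flash_attention

def loss(q, k, v, key_mask):
    # wrap the kernel in jax.checkpoint (remat), as in my training setup
    attn = jax.checkpoint(flash_attention)
    return attn(q, k, v, key_mask).sum()

key = random.PRNGKey(0)
b, h, d = 5, 3, 19

# square case works with the fix, non-square case still fails
for lq, lkv in [(16, 16), (16, 17)]:
    q = random.normal(key, (b, h, lq, d))
    k = random.normal(key, (b, h, lkv, d))
    v = random.normal(key, (b, h, lkv, d))
    key_mask = jnp.ones((b, lkv), dtype=bool)
    try:
        jax.grad(loss)(q, k, v, key_mask)
        print(f"lq={lq}, lkv={lkv}: ok")
    except TypeError as e:
        print(f"lq={lq}, lkv={lkv}: {e}")
```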
https://github.com/midjourney/flash-attention-jax/commit/f690412199178bc60fb4a768f28bffb2f27654cb
the error with `lq = 16; lkv = 17` is `TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).` full backtrace:

```
TypeError Traceback...
```
debugging a bit, it looks like the issue is that `dk` has shape `h, b, lkv, d` and `dk_chunk` has shape `h, b, lq, d`
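(that error reproduces with just the mismatched add, independent of the rest of the kernel:)

```python
import jax.numpy as jnp

dk = jnp.zeros((5, 3, 17, 19))        # shaped like (h, b, lkv, d) per above
dk_chunk = jnp.zeros((5, 3, 16, 19))  # shaped like (h, b, lq, d)

# raises: TypeError: add got incompatible shapes for broadcasting:
# (5, 3, 17, 19), (5, 3, 16, 19).
dk + dk_chunk
```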
@lucidrains looks like there's an implicit assumption somewhere in here that `lq == lkv` in the backwards pass, in `_query_chunk_flash_attention_backward`
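for reference, these are the shapes you'd expect from plain (non-flash) attention math - not the repo's kernel, just a sketch showing that the per-query-chunk contribution to `dk` should already be `lkv`-long, so a `(h, b, lq, d)` `dk_chunk` suggests an einsum/transpose axis that only lines up when `lq == lkv`:

```python
import jax
import jax.numpy as jnp

# reference backward shapes for one query chunk, plain attention (softmax scale omitted)
h, b, lq, lkv, d = 5, 3, 16, 17, 19
q  = jnp.zeros((h, b, lq,  d))
k  = jnp.zeros((h, b, lkv, d))
v  = jnp.zeros((h, b, lkv, d))
do = jnp.zeros((h, b, lq,  d))                     # upstream grad of the output

s  = jnp.einsum('hbid,hbjd->hbij', q, k)           # (h, b, lq, lkv)
p  = jax.nn.softmax(s, axis=-1)
dv = jnp.einsum('hbij,hbid->hbjd', p, do)          # (h, b, lkv, d)
dp = jnp.einsum('hbid,hbjd->hbij', do, v)          # (h, b, lq, lkv)
ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))
dq = jnp.einsum('hbij,hbjd->hbid', ds, k)          # (h, b, lq,  d)
dk = jnp.einsum('hbij,hbid->hbjd', ds, q)          # (h, b, lkv, d)  <- kv-length, not lq
```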
does that work with `lq != lkv`?
looks like it does not
I am trying to use this with models I've already spent a decent amount of compute training; it would be a lot more work to retrain from scratch
I could of course fine-tune with a constant scale, but that seems like a worse option than relying on XLA to fuse here, since the non-cosine-sim version should be drop-in compatible.