Jack Gallagher

Results: 32 comments of Jack Gallagher

interesting that this works with `grad` outside of `scan` and `remat` - it should probably fail under `grad` alone, without either of those?

with the fix, it's working with `lq = lkv` under `jax.checkpoint`! still fails with `lq != lkv`, which I'm trying to debug now

https://github.com/midjourney/flash-attention-jax/commit/f690412199178bc60fb4a768f28bffb2f27654cb
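
a minimal sketch of the failing setup being described - the `flash_attention_jax` import, its signature/return value, and the `(h, b, l, d)` layout are assumptions here, not taken verbatim from this repo:

```python
# hedged repro sketch: the flash_attention import, signature and return value
# are assumed, not verbatim from this repo
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention  # assumed entry point

h, b, d = 5, 3, 19
lq, lkv = 16, 17  # mismatched query / key-value lengths

q = jnp.ones((h, b, lq, d))
k = jnp.ones((h, b, lkv, d))
v = jnp.ones((h, b, lkv, d))

def loss(q, k, v):
    # remat the attention call, as in the jax.checkpoint case above
    out = jax.checkpoint(flash_attention)(q, k, v)  # assumes a single array output
    return out.sum()

# fine when lq == lkv; hits the broadcasting TypeError below when lq != lkv
dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
```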

the error with `lq = 16; lkv = 17` is `TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).` full backtrace:

```
TypeError Traceback...
```

debugging a bit, it looks like the issue is that `dk` has shape `(h, b, lkv, d)` while `dk_chunk` has shape `(h, b, lq, d)`
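
that mismatch reproduces the exact error on its own (shapes taken from the traceback above):

```python
import jax.numpy as jnp

dk = jnp.zeros((5, 3, 17, 19))        # accumulator shaped (h, b, lkv, d)
dk_chunk = jnp.zeros((5, 3, 16, 19))  # chunk shaped (h, b, lq, d)

# raises: TypeError: add got incompatible shapes for broadcasting:
# (5, 3, 17, 19), (5, 3, 16, 19).
dk = dk + dk_chunk
```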

@lucidrains looks like there's an implicit assumption somewhere in `_query_chunk_flash_attention_backward` that `lq == lkv` holds in the backward pass
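
for reference, a plain-softmax sketch of the per-query-chunk gradient shapes (no k/v chunking or log-sum-exp bookkeeping, and names are illustrative rather than the repo's) - the point being that a query chunk's contribution to `dk`/`dv` is keyed by `lkv`, not `lq`, so accumulation should work even when `lq != lkv`:

```python
import jax.numpy as jnp
from jax.nn import softmax

def query_chunk_backward(q_chunk, k, v, do_chunk, scale):
    # q_chunk, do_chunk: (h, b, lq_chunk, d); k, v: (h, b, lkv, d)
    s = jnp.einsum('hbqd,hbkd->hbqk', q_chunk, k) * scale          # (h, b, lq, lkv)
    p = softmax(s, axis=-1)

    dv_chunk = jnp.einsum('hbqk,hbqd->hbkd', p, do_chunk)          # (h, b, lkv, d)
    dp = jnp.einsum('hbqd,hbkd->hbqk', do_chunk, v)                # (h, b, lq, lkv)
    ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))        # vjp through softmax

    dq_chunk = jnp.einsum('hbqk,hbkd->hbqd', ds, k) * scale        # (h, b, lq, d)
    dk_chunk = jnp.einsum('hbqk,hbqd->hbkd', ds, q_chunk) * scale  # (h, b, lkv, d)

    # dk_chunk / dv_chunk have length lkv on the sequence axis regardless of
    # the query chunk length, so summing them into full-size dk / dv is safe
    return dq_chunk, dk_chunk, dv_chunk
```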

does that work with `lq != lkv`?

I am trying to use this with models I've already spent a decent amount of compute training; retraining from scratch would be a lot more work

could of course tune with a constant scale, but that seems like a worse option than relying on xla to fuse here, since the non-cosine-sim version should be drop-in compatible.
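
to make that concrete, a rough sketch of the two logit conventions being compared (the `scale` constant is just an illustrative value):

```python
import jax.numpy as jnp

def standard_logits(q, k):
    # scaled dot-product logits - drop-in compatible with models already
    # trained with standard attention
    return jnp.einsum('...qd,...kd->...qk', q, k) / jnp.sqrt(q.shape[-1])

def cosine_sim_logits(q, k, scale=10.0):
    # cosine-sim attention: l2-normalize q and k, then multiply by the
    # constant scale that would otherwise need tuning
    l2norm = lambda t: t / (jnp.linalg.norm(t, axis=-1, keepdims=True) + 1e-8)
    return jnp.einsum('...qd,...kd->...qk', l2norm(q), l2norm(k)) * scale
```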