Anique

Results 3 comments of Anique

It's because they replaced 'blacklist' with 'denylist' ...

It's supposed to be both faster and less memory-intensive. lucidrains' version is implemented in plain JAX rather than as a low-level kernel, so it is slower than the normal version.
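To illustrate where the memory saving comes from, here is a minimal sketch of the blockwise (online) softmax idea for a single query vector, in plain Python. This is not lucidrains' code or any library's API: the function names and `block_size` parameter are made up for illustration, and the usual `1/sqrt(d)` score scaling is left out to keep it short.

```python
import math

def naive_attention_row(q, keys, values):
    """Reference: materialize every score, then softmax, then weighted sum."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def blockwise_attention_row(q, keys, values, block_size=2):
    """Same output, but only one block of scores is held at a time."""
    dim = len(values[0])
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator, relative to running_max
    acc = [0.0] * dim         # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (0.0 on the first block).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Both functions should agree up to floating-point error for any `block_size`. The saving only turns into a speedup when the blocks are processed in fast on-chip memory by a fused kernel, which is why a version written with ordinary array ops can end up slower.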

> `jax.experimental` has an [implementation](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) of FlashAttention, written by Pallas kernels and therefore usable in both GPU and TPU.
>
> We can probably upstream this to Flax attention if...