Comments of Anique (3 results)
It's because they replaced 'blacklist' with 'denylist' ...
It's supposed to be both faster and less memory-intensive. lucidrains' version is implemented in plain JAX rather than as a low-level kernel, so it ends up slower than the standard version.
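The memory saving comes from the tiling idea itself, which can be written in ordinary JAX; the speed comes from fusing it into one low-level kernel. The sketch below is only an illustration under stated assumptions, not lucidrains' code or any library's API: `naive_attention`, `blockwise_attention` and `block_size` are hypothetical names, inputs are single-head 2D arrays, and the key length is assumed divisible by `block_size`. It walks over key/value blocks with `jax.lax.scan` and keeps a running max and normaliser (the "online softmax"), so the full score matrix is never materialised. Because this is still a sequence of regular XLA ops rather than a single fused kernel keeping tiles on-chip, it would typically not match the speed of a CUDA or Pallas kernel.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    # Materialises the full (seq_q, seq_k) score matrix in memory.
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = (q @ k.T) * scale
    return jax.nn.softmax(scores, axis=-1) @ v


def blockwise_attention(q, k, v, block_size=64):
    # Hypothetical FlashAttention-style tiling in plain JAX: iterate over
    # key/value blocks, keeping a running max (m), normaliser (l) and output
    # accumulator (acc), so only a (seq_q, block_size) score tile exists at once.
    # Assumes seq_k is divisible by block_size.
    seq_k, d = k.shape
    scale = 1.0 / jnp.sqrt(d)
    k_blocks = k.reshape(seq_k // block_size, block_size, d)
    v_blocks = v.reshape(seq_k // block_size, block_size, v.shape[-1])

    def step(carry, kv):
        m, l, acc = carry
        k_blk, v_blk = kv
        s = (q @ k_blk.T) * scale                # (seq_q, block_size) scores
        m_new = jnp.maximum(m, s.max(axis=-1))   # updated running max per row
        p = jnp.exp(s - m_new[:, None])          # unnormalised probabilities
        correction = jnp.exp(m - m_new)          # rescales earlier statistics
        l_new = l * correction + p.sum(axis=-1)
        acc_new = acc * correction[:, None] + p @ v_blk
        return (m_new, l_new, acc_new), None

    seq_q = q.shape[0]
    init = (
        jnp.full((seq_q,), -jnp.inf, dtype=q.dtype),
        jnp.zeros((seq_q,), dtype=q.dtype),
        jnp.zeros((seq_q, v.shape[-1]), dtype=q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / l[:, None]


key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (128, 32))
k = jax.random.normal(kk, (256, 32))
v = jax.random.normal(kv, (256, 32))
# Both paths should agree up to float32 rounding.
print(jnp.allclose(naive_attention(q, k, v), blockwise_attention(q, k, v), atol=1e-4))
```

The design trade-off is the one described above: the blockwise version bounds peak memory by `block_size`, but each scan step still launches separate operations, which is where a hand-written kernel wins.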
> `jax.experimental` has an [implementation](https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py) of FlashAttention, written by Pallas kernels and therefore usable in both GPU and TPU.
>
> We can probably upstream this to Flax attention if...