
Feature request: FlashAttention

Open OhadRubin opened this issue 1 year ago • 15 comments

Hey, there is an implementation of this memory-efficient version of attention here, and I was wondering if it is possible to somehow integrate it into Flax.

OhadRubin avatar Feb 09 '23 15:02 OhadRubin

~~Instead of flash attention, maybe consider memory-efficient attention? https://arxiv.org/abs/2112.05682~~

~~I tested flash attention and noticed it's both slower and seems to take more memory. Tests were done on TPU, so it's possible the higher memory usage is due to flash attention's slicing causing more padding.~~

~~Memory-efficient attention has shown a decrease in memory use. The paper actually comes with a JAX implementation. Google Research's implementation can be found at https://github.com/google-research/google-research/tree/master/memory_efficient_attention~~
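For context, the core idea of that paper is a chunked, running softmax that never materializes the full score matrix. Below is a minimal single-head sketch of my own (not the paper's or Google Research's reference code):

```python
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, chunk_size=128):
    """Memory-efficient attention sketch: scan over key/value chunks,
    keeping a running max and normalizer so the full [q_len, kv_len]
    score matrix is never materialized. Assumes kv_len % chunk_size == 0."""
    q_len, head_dim = q.shape
    scale = head_dim ** -0.5
    num_chunks = k.shape[0] // chunk_size
    k = k.reshape(num_chunks, chunk_size, head_dim)
    v = v.reshape(num_chunks, chunk_size, head_dim)

    def body(carry, kv_chunk):
        acc, denom, running_max = carry
        k_c, v_c = kv_chunk
        s = (q @ k_c.T) * scale                      # [q_len, chunk_size]
        new_max = jnp.maximum(running_max, s.max(-1))
        correction = jnp.exp(running_max - new_max)  # rescale previous partial sums
        p = jnp.exp(s - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_c
        denom = denom * correction + p.sum(-1)
        return (acc, denom, new_max), None

    init = (jnp.zeros((q_len, head_dim)),
            jnp.zeros(q_len),
            jnp.full(q_len, -jnp.inf))
    (acc, denom, _), _ = jax.lax.scan(body, init, (k, v))
    return acc / denom[:, None]
```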

The new JAX Pallas grants access to lower-level optimization. A flash attention implementation should now be faster and more efficient.

Lime-Cakes avatar Mar 31 '23 17:03 Lime-Cakes

It's supposed to be both faster and less memory-intensive. The lucidrains version is implemented in plain JAX instead of at a low level, so it is slower than the normal version.

aniquetahir avatar Jul 02 '23 14:07 aniquetahir

+1 Any update on this request?

agemagician avatar Oct 13 '23 17:10 agemagician

jax.experimental has an implementation of FlashAttention, written with Pallas kernels and therefore usable on both GPU and TPU.

We can probably upstream this to Flax attention if jax.experimental doesn't scare us (it means the API could still change).
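If it helps anyone trying it out, a call to the GPU kernel looks roughly like the sketch below. I'm assuming the `mha` entry point in `jax.experimental.pallas.ops.attention` and a `[batch, seq_len, num_heads, head_dim]` layout; since this lives under `jax.experimental`, the exact path and signature may differ in your JAX version.

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention  # GPU (Triton) flash attention

# Assumed layout: [batch, seq_len, num_heads, head_dim]; requires a GPU backend.
q = jnp.ones((1, 1024, 8, 64), dtype=jnp.float16)
k = jnp.ones((1, 1024, 8, 64), dtype=jnp.float16)
v = jnp.ones((1, 1024, 8, 64), dtype=jnp.float16)

out = attention.mha(q, k, v, segment_ids=None, causal=True)

# The kernel defines a custom VJP, so it can be differentiated like any JAX op.
grads = jax.grad(lambda q: attention.mha(q, k, v, segment_ids=None, causal=True).sum())(q)
```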

IvyZX avatar Oct 14 '23 00:10 IvyZX

> jax.experimental has an implementation of FlashAttention, written with Pallas kernels and therefore usable on both GPU and TPU.
>
> We can probably upstream this to Flax attention if jax.experimental doesn't scare us (it means the API could still change).

Does that support both backprop and forward prop?

aniquetahir avatar Oct 14 '23 01:10 aniquetahir

> jax.experimental has an implementation of FlashAttention, written with Pallas kernels and therefore usable on both GPU and TPU. We can probably upstream this to Flax attention if jax.experimental doesn't scare us (it means the API could still change).

> Does that support both backprop and forward prop?

Yes, it seems so.

agemagician avatar Oct 14 '23 07:10 agemagician

> jax.experimental has an implementation of FlashAttention, written with Pallas kernels and therefore usable on both GPU and TPU.
>
> We can probably upstream this to Flax attention if jax.experimental doesn't scare us (it means the API could still change).

Yes, but it seems to be missing some parameters like dropout, mask, and bias, if I am not mistaken.

agemagician avatar Oct 14 '23 07:10 agemagician

One more point: JAX has two implementations:

  1. Fused attention: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py
  2. Flash attention: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py

agemagician avatar Oct 14 '23 20:10 agemagician

Update: my comment was made before Pallas was an option, when low-level access was simply not possible.

The official JAX flash attention implementation is built on Pallas and thus not subject to the same issue from 5+ months ago. Flash attention should be viable right now, and likely faster than the older memory-efficient attention I suggested.

Lime-Cakes avatar Oct 17 '23 08:10 Lime-Cakes

Looks like the two attention kernels are for different platforms - one for TPU and another for GPU. The implementations are slightly different for performance reasons, and they also have different APIs.

It would be cool if we can unify the two and provide a stable attention API, either at Flax level or at JAX level. I will look more into that and collaborate with JAX folks.

For now, feel free to just use one of those kernels based on your platform needs.

IvyZX avatar Oct 19 '23 20:10 IvyZX

Thanks a lot @IvyZX for integrating flash attention.

I am just afraid that it is still missing some parameters like dropout, mask, and bias, if I am not mistaken.

agemagician avatar Oct 19 '23 20:10 agemagician

@IvyZX could you please share the link to the current PR?

agemagician avatar Oct 26 '23 21:10 agemagician

Dropout is not currently available in Pallas kernels, as Pallas does not yet support PRNG keys. Causal masking can be turned on with `causal=True`, and the TPU version has an attention bias argument (`ab` - unclear naming IMO).
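For anyone who needs the bias today, a TPU kernel call would look roughly like this. I'm assuming the `flash_attention` entry point in `jax.experimental.pallas.ops.tpu.flash_attention`, a `[batch, num_heads, seq_len, head_dim]` layout, and that `ab` is an additive pre-softmax bias; please double-check against the source for your JAX version.

```python
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu import flash_attention as fa  # TPU kernel

batch, heads, seq, head_dim = 1, 8, 1024, 64
# Assumed layout for the TPU kernel: [batch, num_heads, seq_len, head_dim].
q = jnp.ones((batch, heads, seq, head_dim), dtype=jnp.float32)
k = jnp.ones((batch, heads, seq, head_dim), dtype=jnp.float32)
v = jnp.ones((batch, heads, seq, head_dim), dtype=jnp.float32)

# Additive attention bias (e.g. ALiBi slopes or T5 relative-position biases),
# broadcast to [batch, num_heads, q_len, kv_len].
bias = jnp.zeros((batch, heads, seq, seq), dtype=jnp.float32)

out = fa.flash_attention(q, k, v, ab=bias, causal=True)
```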

I wonder what the real-life use cases for attention bias are, as many LLMs nowadays don't use bias anymore. Could anyone share why they would like to have it?

IvyZX avatar Oct 27 '23 22:10 IvyZX

If I am not mistaken, many state-of-the-art positional embedding methods require an attention bias, such as ALiBi and the relative attention encoding used in T5 models.
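To make the use case concrete, ALiBi is just an additive, per-head linear penalty on the distance to past tokens, so all it needs from the kernel is an additive bias term (e.g. the `ab` argument above). A simplified sketch of my own, not the reference implementation:

```python
import jax.numpy as jnp

def alibi_bias(num_heads: int, seq_len: int) -> jnp.ndarray:
    """Simplified ALiBi bias of shape [num_heads, seq_len, seq_len].

    bias[h, i, j] = -slope_h * (i - j): a per-head linear penalty on the
    distance to past tokens; positions j > i get positive values, which are
    assumed to be removed by the causal mask anyway.
    """
    # Geometric slopes as in the ALiBi paper (exact for power-of-two head counts).
    slopes = 2.0 ** (-8.0 * jnp.arange(1, num_heads + 1) / num_heads)
    pos = jnp.arange(seq_len)
    rel = pos[None, :] - pos[:, None]           # rel[i, j] = j - i  (<= 0 for past tokens)
    return slopes[:, None, None] * rel[None, :, :]

# Broadcast to [batch, heads, seq, seq] before passing as an attention bias.
bias = alibi_bias(num_heads=8, seq_len=1024)[None]
```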

agemagician avatar Oct 28 '23 09:10 agemagician

It seems that this Triton implementation supports attention bias, so there is nothing that prevents the algorithm from supporting it: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py Additionally, this implementation of blockwise parallel attention (faster than FlashAttention-1) also supports attention bias: https://github.com/lhao499/llm_large_context/blob/main/bpt.py

Edit: Changed bpt.py to the original paper repo

OhadRubin avatar Nov 10 '23 03:11 OhadRubin