
Performance benchmarks?

Open imoneoi opened this issue 2 years ago • 21 comments

Are there any benchmark results now? Looking forward to performance comparisons with original attention and the official torch+CUDA implementation.

imoneoi avatar Dec 07 '22 11:12 imoneoi

I am also curious. Additionally, maybe it is possible to use CUDA code with JAX?

https://github.com/dfm/extending-jax

jakubMitura14 avatar Dec 29 '22 12:12 jakubMitura14

https://colab.research.google.com/drive/1-YCU9ps4gNuROJ3_8MLjSpbICGHaySxh?usp=sharing

OhadRubin avatar Feb 27 '23 15:02 OhadRubin

Fantastic! Have you run the same experiment with the same data on the original flash attention?

jakubMitura14 avatar Feb 27 '23 15:02 jakubMitura14

Not yet

OhadRubin avatar Feb 27 '23 18:02 OhadRubin

Hello, could I ask if this works with TPUs?

jon-chuang avatar Apr 06 '23 22:04 jon-chuang

Here's an updated notebook for anyone interested; it precompiles the jitted functions and blocks on the results until they're ready:

https://colab.research.google.com/drive/11QKRdgMtcivrJNmjTrf2bXTE5yXkXl_Z?usp=sharing

It looks like JAX compiles vanilla attention in a way that is faster than JAX flash attention, so there's no need to switch to flash attention if you use JAX.
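For reference, here is a rough sketch of that benchmarking pattern (warm up the jitted function first, then time with block_until_ready so JAX's asynchronous dispatch doesn't hide the real cost); the vanilla attention below is purely illustrative:

```python
import time

import jax
import jax.numpy as jnp

# Illustrative vanilla attention; shapes are (batch, seq, heads, head_dim).
@jax.jit
def vanilla_attention(q, k, v):
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)

q = k = v = jnp.ones((1, 2048, 8, 64), dtype=jnp.bfloat16)

# Warm-up call triggers compilation so it isn't counted in the timing.
vanilla_attention(q, k, v).block_until_ready()

start = time.perf_counter()
for _ in range(10):
    out = vanilla_attention(q, k, v)
out.block_until_ready()  # wait for the async dispatch queue to drain
print(f"{(time.perf_counter() - start) / 10 * 1e3:.2f} ms per call")
```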

evanatyourservice avatar Oct 21 '23 15:10 evanatyourservice

Wow, this has been open for almost a year...

I think someone could get a lot of citations/clicks if they did a proper benchmark of transformer train/inference across platforms (torch/jax, GPU/TPU) with standard tricks. I would cite you straight away for sure. It would also just be nice to settle the dispute that Google employees seem to have with everyone else about whether JAX is meaningfully more efficient (it might just come down to TPU vs. GPU!?).

SamuelGabriel avatar Nov 11 '23 09:11 SamuelGabriel

> Wow, this has been open for almost a year...
>
> I think someone could get a lot of citations/clicks if they did a proper benchmark of transformer train/inference across platforms (torch/jax, GPU/TPU) with standard tricks. I would cite you straight away for sure. It would also just be nice to settle the dispute that Google employees seem to have with everyone else about whether JAX is meaningfully more efficient (it might just come down to TPU vs. GPU!?).

It would definitely be nice to see such a benchmark, but I can imagine how hard it is to compare JAX vs. PyTorch (GPU/TPU) with so many optimized implementations for each device. For PyTorch on GPU we have Triton/CUDA, but JAX has also recently added a Triton-like mechanism for writing custom kernels for GPU/TPU: Pallas. You can even find an implementation of attention written in it here.
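For anyone who hasn't looked at Pallas yet, here is a minimal sketch of what such a kernel looks like (an elementwise add rather than attention, following the pattern from the Pallas quickstart; the names are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each kernel instance reads its block from the input refs and
    # writes the elementwise sum into the output ref.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```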

niemiaszek avatar Nov 29 '23 16:11 niemiaszek

@niemiaszek I just recently saw that they named and added docs for Pallas; it looks very interesting. The JAX team is also improving our ability to customize how networks are sharded across accelerators, and they're publishing papers on their efficiency results, which is pretty cool I think. Unfortunately I don't have time to do a fair comparison between torch and jax with attention, but it seems that whoever takes the time to delve into it, especially JAX's recent improvements, would certainly benefit if they have a need.

Even if we don't take the time, it looks like the JAX team continually folds its efficiency findings into JAX as defaults, so we don't have to implement them ourselves.
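For the curious, a minimal sketch of those sharding controls (the axis name "data" is just illustrative): build a device mesh, describe a partitioning, and place an array on it; jitted computations then follow the array's layout.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are available, with a named "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # split the leading array axis

n = jax.device_count()
x = jax.device_put(jnp.arange(4.0 * n).reshape(2 * n, 2), sharding)
print(x.sharding)  # shows how the leading axis is split across devices
```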

evanatyourservice avatar Nov 29 '23 16:11 evanatyourservice

from what i've heard, flash attention doesn't work well on TPUs, but i haven't kept up with the latest iteration of their chip design.

Pallas is just a wrapper around Triton, which was developed at OpenAI for GPUs. you will basically always be limited by what the Triton compiler can do

lucidrains avatar Nov 29 '23 16:11 lucidrains

while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for having finely controlled GPGPU programming.

lucidrains avatar Nov 29 '23 16:11 lucidrains

@lucidrains I'd agree as far as single-device optimizations go. I solely use jax because my work deals mainly with RL and I've already built everything out, but for things like language and vision models, resources like xformers are hard to beat. I do like jax's work toward multi-device customization especially from an RL perspective.

evanatyourservice avatar Nov 29 '23 16:11 evanatyourservice

> while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for having finely controlled GPGPU programming.

Well, I would argue that these days it's no longer such a hard pill, given the wide adoption of tiled programming paradigms like Triton (e.g. PyTorch, with both codegen and incoming custom kernels; JAX, with Pallas; and hardware vendors including NVIDIA, AMD, and Intel), which greatly reduce the effort and complexity of getting SOTA perf on GPUs.

jon-chuang avatar Nov 29 '23 17:11 jon-chuang

@jon-chuang hmm, still a bit early to declare that imho

we'll see, i hope so!

lucidrains avatar Nov 29 '23 17:11 lucidrains

Yes, Triton is still not 100% there (some matmul kernel sizes and certain kernels like the flash attention backward pass are still not SOTA). But it's certainly the direction the industry is investing in, and IMO it's good news for developers and tinkerers who want hackability at each layer of the stack.

I've already heard of some success stories with customizing flash attention kernels via Triton.

jon-chuang avatar Nov 29 '23 17:11 jon-chuang

I think these newish attention replacements will take time to be adopted, particularly because the dust has not settled on them and it takes a while for wide-scale experimentation and large-scale training to truly prove them out.

IMO all it takes is for one highly-funded industrial lab to go out on a limb and train an LLM with one of these...

For instance, Mistral AI essentially has a linear-cost attention mechanism based on SWA (sliding window attention); one could of course argue about how effective it is at truly capturing information across a long context.
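For illustration, here is a minimal sketch of the sliding-window idea as a plain attention mask (just the masking pattern, not Mistral's actual implementation): token i attends only to the previous `window` positions, itself included.

```python
import jax.numpy as jnp

def sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray:
    # True where query position i may attend to key position j:
    # causal (j <= i) and within the window (i - j < window).
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & ((i - j) < window)

# Each row has at most `window` True entries, so per-token attention
# cost is O(window) instead of O(seq_len).
mask = sliding_window_mask(seq_len=6, window=3)
```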

> all these frameworks cannot do.

I think this is an overstatement? It simply has not been tried in Triton yet, and it should not be that hard, though whether the performance matches is an open question.

I just hope that more devs become aware of how powerful triton is so that there's more experimentation with implementing these kernels.

jon-chuang avatar Nov 29 '23 17:11 jon-chuang

@jon-chuang yea, let us just agree that we both wish for Triton and the like to succeed so us non-CUDA experts can have control over the entire stack

i just know it isn't there yet.

lucidrains avatar Nov 29 '23 18:11 lucidrains

Interestingly, a basic building block for Mamba (associative scan) already has support in Triton: https://github.com/pytorch/pytorch/issues/95408#issuecomment-1653748896
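For context, JAX exposes this primitive as jax.lax.associative_scan. A minimal sketch, scanning the linear recurrence h_t = a_t * h_{t-1} + b_t (the kind of recurrence SSMs build on) via the standard affine-map composition:

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Compose two affine maps h -> a*h + b: applying `left` then `right`
    # gives h -> (a_r * a_l) * h + (a_r * b_l + b_r).
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

a = jnp.array([0.5, 0.9, 0.8, 0.7])
b = jnp.array([1.0, 2.0, 3.0, 4.0])

# Running states h_1..h_4 of h_t = a_t * h_{t-1} + b_t with h_0 = 0.
_, h = jax.lax.associative_scan(combine, (a, b))
print(h)
```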

jon-chuang avatar Nov 30 '23 05:11 jon-chuang

it doesn't support multiple inputs. also i heard it is still buggy in its current state

lucidrains avatar Nov 30 '23 14:11 lucidrains

@jon-chuang anyways, let us take the discussion elsewhere, as this is about flash attention

lucidrains avatar Nov 30 '23 15:11 lucidrains

Flash attention is now available in jax-nightly with a cuDNN implementation: jax.nn.dot_product_attention. It only supports the Ampere architecture and later.

Note that the default implementation is xla.
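A quick sketch of how that looks (assuming a recent jax-nightly and an Ampere-or-later GPU; shapes follow the documented [batch, seq_len, num_heads, head_dim] layout):

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 1024, 8, 64
key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(jax.random.fold_in(key, i), (B, T, N, H), dtype=jnp.bfloat16)
           for i in range(3))

out_xla = jax.nn.dot_product_attention(q, k, v)  # default: xla
out_cudnn = jax.nn.dot_product_attention(q, k, v, implementation="cudnn")  # cuDNN flash attention
```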

MasterSkepticista avatar Jul 21 '24 04:07 MasterSkepticista