dfdx

Flash Attention

Open jafioti opened this issue 3 years ago • 7 comments

Related to #590

In my quest to speed up my language model training, I've been looking to implement Flash Attention (https://github.com/HazyResearch/flash-attention). They report memory requirements that are linear in sequence length (compute is still quadratic) and much faster performance thanks to I/O-aware optimization, i.e. fewer reads and writes between GPU HBM and on-chip SRAM.
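To illustrate where the linear memory comes from, here is a minimal CPU sketch in Rust of the "online softmax" trick that Flash Attention builds on. It is not the HazyResearch kernel: it handles a single query row, and the function names and the `block` parameter are hypothetical. Keys and values are visited `block` rows at a time while a running max, running sum, and unnormalized output are rescaled as each block arrives, so the full row of N attention scores is never materialized. A real kernel also tiles the queries and keeps each tile in GPU shared memory.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Reference attention for one query row: materializes all N scores at once.
fn attention_naive(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys.iter().map(|k| dot(q, k) * scale).collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w * x / sum;
        }
    }
    out
}

/// Streaming attention: only `block` scores plus O(d) running state live at a time.
/// `block` must be > 0 (`chunks` panics on 0).
fn attention_online(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], block: usize) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut running_max = f32::NEG_INFINITY;
    let mut running_sum = 0.0f32;
    let mut acc = vec![0.0f32; values[0].len()]; // unnormalized weighted sum of values
    for (k_blk, v_blk) in keys.chunks(block).zip(values.chunks(block)) {
        let scores: Vec<f32> = k_blk.iter().map(|k| dot(q, k) * scale).collect();
        let blk_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_max = running_max.max(blk_max);
        // Rescale everything accumulated under the old max. On the first block
        // this is exp(-inf) = 0, which is harmless since nothing is accumulated yet.
        let correction = (running_max - new_max).exp();
        running_sum *= correction;
        for a in acc.iter_mut() {
            *a *= correction;
        }
        for (s, v) in scores.iter().zip(v_blk) {
            let w = (s - new_max).exp();
            running_sum += w;
            for (a, x) in acc.iter_mut().zip(v) {
                *a += w * x;
            }
        }
        running_max = new_max;
    }
    // Normalize once at the end.
    acc.iter().map(|a| a / running_sum).collect()
}

fn main() {
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0], vec![-1.0, 0.5], vec![0.3, -0.2]];
    let values = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0], vec![7.0, 8.0], vec![9.0, 10.0]];
    let q = vec![0.5, -0.25];
    let naive = attention_naive(&q, &keys, &values);
    let online = attention_online(&q, &keys, &values, 2);
    for (a, b) in naive.iter().zip(&online) {
        assert!((a - b).abs() < 1e-5);
    }
    println!("{:?}", online);
}
```

The result matches the naive version for any block size, which is what lets a GPU kernel choose a block size that fits in SRAM.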

This would be a great addition to MHA. Adding ALiBi biases (https://arxiv.org/abs/2108.12409) and causal masking directly in the kernel would also be huge for perf, but I don't know how that could be added in a nice modular way, so it might make sense to start with plain flash attention.
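As a rough illustration of what "directly in the kernel" could mean, the Rust sketch below computes one query row of causally masked attention with the ALiBi bias added to the scaled scores on the fly, instead of building a separate N×N bias/mask tensor. Following the ALiBi paper, head h of n heads (n a power of two) gets slope m_h = 2^(-8h/n), and the bias for query i and key j is -m_h·(i - j). Keys with j > i are simply skipped, which is equivalent to a -inf mask. The function names are hypothetical and this is a CPU reference, not a GPU kernel.

```rust
/// ALiBi slopes 2^(-8h/n) for h = 1..=n; the paper's formula assumes n is a power of two.
fn alibi_slopes(n_heads: usize) -> Vec<f32> {
    (1..=n_heads)
        .map(|h| 2f32.powf(-8.0 * h as f32 / n_heads as f32))
        .collect()
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Attention output for query position `i` of one head. The causal mask and the
/// ALiBi bias are applied while the scores are computed, with no N x N tensor.
fn causal_alibi_row(i: usize, q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>], slope: f32) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    // Causal mask: only keys 0..=i are visited (same as adding -inf for j > i).
    let scores: Vec<f32> = keys[..=i]
        .iter()
        .enumerate()
        .map(|(j, k)| dot(q, k) * scale - slope * (i - j) as f32) // ALiBi: farther keys get a larger penalty
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; values[0].len()];
    for (w, v) in weights.iter().zip(&values[..=i]) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w / sum * x;
        }
    }
    out
}

fn main() {
    let slopes = alibi_slopes(4);
    let keys = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let values = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    let q = vec![0.5, 0.5];
    for (h, &m) in slopes.iter().enumerate() {
        let out = causal_alibi_row(2, &q, &keys, &values, m);
        println!("head {h} (slope {m}): {out:?}");
    }
}
```

Because the mask is "skip j > i" rather than a tensor, a tiled kernel can typically skip whole key blocks above the diagonal, which is part of why fusing it is attractive for perf.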

jafioti, Apr 14 '23

CUDA source for flash attention, for later reference: https://github.com/HazyResearch/flash-attention/tree/96d10f654527cc82c81022e16f77a8d9564f7eba/csrc/flash_attn

chelsea0x3b, Apr 19 '23

A couple notes for this:

  1. The HazyResearch kernels are tightly intertwined with libtorch; e.g. they use the ATen library in a number of important places. This makes using them directly very complex.
  2. Triton has an example of a fused attention kernel written in Python: https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py. It may be possible to use Triton to compile kernels, but a couple of things would need to be worked out:
    1. Is it possible to get Triton to output PTX files for ingestion by dfdx?
    2. Is it possible to invoke the Triton compiler from Rust? Currently there is no Rust frontend. There is a tracking issue, https://github.com/openai/triton/issues/153, but it seems stale.

chelsea0x3b, May 09 '23

@coreylowman I know the answer to i) is yes: you can get a PTX file directly from a compiled kernel.

I don't know about calling Triton from Rust. Certainly, for now we can just use Triton from Python to get a PTX file and import it manually into dfdx. It would be great if that could be automated in the future, but idk.

jafioti, May 09 '23

Yeah I like that idea, seems easy enough to try as well.

chelsea0x3b, May 09 '23

Has there been any progress on this? Apparently there is now a Flash Attention 2 as well, in the same repo. Here is the tech report: https://tridao.me/publications/flash2/flash2.pdf

It reports significant speedups over the original Flash Attention.

vikigenius, Jul 18 '23

I also followed up on the Triton thread (https://github.com/openai/triton/issues/153), and it seems that even though https://github.com/openai/triton/pull/1056 was closed, https://github.com/openai/triton/pull/1805 did get merged. I am not sure how much more work is needed on top of that to get a functioning Rust frontend.

vikigenius, Jul 18 '23

Yeah, I saw Flash Attention 2; I think it still has the same issue of being very tightly coupled with libtorch. If Triton merged the AOT changes, then I think all we need to do is compile a PTX for the flash attention kernel? We'd likely include the actual PTX in dfdx, and then we'd have to make a flash attention tensor op.

chelsea0x3b, Jul 19 '23