
Turing support?

Open Ph0rk0z opened this issue 1 year ago • 16 comments

For a while it was said that Turing support was coming. What is preventing compilation on Turing now? Is it the lack of BF16? Can that just be disabled for Turing, or disabled overall with a compile flag? Is there any more to it? There are now some 22 GB 2080 Ti cards, so it would be cool to have them work alongside Ampere cards.

Ph0rk0z avatar Jan 24 '24 12:01 Ph0rk0z
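For reference, a minimal host-side sketch of what "disabling BF16 on Turing" could look like: check the device's compute capability at runtime and route bfloat16 requests to the fp16 path, since native BF16 support starts at compute capability 8.0 (Ampere). This is illustrative CUDA, not code from the flash-attention repository.

```cuda
// Illustrative only: not from the flash-attention sources. Native BF16 support
// (tensor cores + bf16 math intrinsics) starts at compute capability 8.0, so
// one way to "disable BF16 on Turing" is a host-side capability check that
// falls back to fp16. A compile-time equivalent would guard device code with
// `#if __CUDA_ARCH__ >= 800`.
#include <cstdio>
#include <cuda_runtime.h>

static bool device_supports_bf16(int device) {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 8;  // Ampere (sm_80) and newer
}

int main() {
    int device = 0;
    cudaSetDevice(device);
    if (device_supports_bf16(device)) {
        std::printf("bf16 kernels available\n");
    } else {
        // On a 2080 Ti (sm_75) we land here and use fp16 kernels instead.
        std::printf("no native bf16: falling back to fp16 kernels\n");
    }
    return 0;
}
```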

I'm bumping this for use in the Google Colab T4 environment.

GloomyC avatar Jan 25 '24 16:01 GloomyC

Unfortunately I haven't had much bandwidth.

tridao avatar Jan 25 '24 18:01 tridao

I was able to get everything to compile, but in some cases I get an "invalid argument" CUDA assertion. Also not sure about the performance.

Ph0rk0z avatar Jan 31 '24 01:01 Ph0rk0z

Turing cards have less shared memory (64KB instead of 99KB or 163KB on Ampere) so that might require adjusting the block sizes currently used.

tridao avatar Jan 31 '24 01:01 tridao
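A rough back-of-the-envelope illustration of the constraint described above, using a simplified tile model (one Q tile plus one K and one V tile in fp16; the real kernels' shared-memory layout differs):

```cuda
// Back-of-the-envelope only; the actual flash-attention tile layout differs.
// Model: one BlockM x d tile of Q plus one BlockN x d tile each of K and V,
// all in fp16 (2 bytes per element).
#include <cstdio>

constexpr int kBytesPerElt = 2;  // fp16 / bf16

constexpr int smem_bytes(int block_m, int block_n, int head_dim) {
    return (block_m + 2 * block_n) * head_dim * kBytesPerElt;
}

int main() {
    // 128x128 tiles at head_dim 128 need ~96 KB: within Ampere's 99/163 KB
    // budgets, but over Turing's 64 KB, hence the need for smaller tiles.
    std::printf("128x128, d=128: %d KB\n", smem_bytes(128, 128, 128) / 1024);
    std::printf(" 64x128, d=128: %d KB\n", smem_bytes(64, 128, 128) / 1024);
    std::printf(" 64x 64, d=128: %d KB\n", smem_bytes(64, 64, 128) / 1024);
    return 0;
}
```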

It worked on small outputs but then failed on large ones. I have to dig in to see where. I didn't edit the block-size section, which is mainly written for sm8x/sm90, and that's a likely culprit. I am really only using the forward pass, for exllamav2.

edit: I tried limiting the kernel to blocks of 32, but I'm still getting the invalid argument error. I spent all day trying to figure out how to get a line number out of PyTorch, but it seems I need a build with device-side assertions (DSA) enabled (or compile PyTorch myself, ouch) or some other method of debugging.

Ph0rk0z avatar Jan 31 '24 12:01 Ph0rk0z
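For what it's worth, one generic way to localize an "invalid argument" without device-side assertions or a PyTorch rebuild is to check the CUDA error state right after each launch. A common cause of this error on Turing is a launch requesting more dynamic shared memory than the 48 KB default (or the 64 KB per-block hard limit on sm_75). A standalone sketch, not tied to the flash-attention kernels:

```cuda
// Generic CUDA debugging sketch, not tied to the flash-attention kernels.
// "invalid argument" at launch is often the launch configuration itself: e.g.
// requesting more dynamic shared memory than the 48 KB default without the
// cudaFuncSetAttribute opt-in, or more than the 64 KB per-block limit on sm_75.
// Checking the error state right after each call pinpoints the failing launch.
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_CHECK(expr)                                                      \
    do {                                                                      \
        cudaError_t err_ = (expr);                                            \
        if (err_ != cudaSuccess)                                              \
            std::printf("%s:%d %s -> %s\n", __FILE__, __LINE__, #expr,        \
                        cudaGetErrorString(err_));                            \
    } while (0)

__global__ void dummy_kernel(float* out) {
    extern __shared__ float tile[];
    tile[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) *out = tile[0];
}

int main() {
    float* out = nullptr;
    CUDA_CHECK(cudaMalloc(&out, sizeof(float)));

    size_t smem_bytes = 80 * 1024;  // deliberately > 64 KB: fails on Turing
    // Required for any launch that wants > 48 KB of dynamic shared memory; on
    // sm_75 this call itself fails once the request exceeds 64 KB.
    CUDA_CHECK(cudaFuncSetAttribute(dummy_kernel,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    static_cast<int>(smem_bytes)));
    dummy_kernel<<<1, 128, smem_bytes>>>(out);
    CUDA_CHECK(cudaGetLastError());       // catches launch-configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // catches faults inside the kernel
    CUDA_CHECK(cudaFree(out));
    return 0;
}
```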

Interested in this too.

Elsayed91 avatar Mar 01 '24 13:03 Elsayed91

Also interested in this, to run on AWS Graviton servers.

elkay avatar Mar 11 '24 13:03 elkay

Interested in this too.

laoda513 avatar Apr 01 '24 13:04 laoda513

Interested in this as well.

I am now working on a pet project that combines the approach from the "era of 1-bit LLMs" paper (which efficiently shows that ~1.5 bits per parameter is enough) with the ReLoRA paper. Unlike the original "1-bit LLM" authors, who showed that 1.5 bits of information capacity is sufficient, I actually use models quantized with that approach, so I need a method that does not pass gradients through the main model weights.

However, even so, I still need to produce fairly large intermediate matrices during the forward and backward passes, which Flash Attention computes on the fly instead, as far as I understand?

So it would be quite nice to see how all three techniques work combined (minimal model weights + no large gradient/optimizer state thanks to ReLoRA + fewer materialized intermediate matrices thanks to FlashAttention).

alex4321 avatar Apr 05 '24 21:04 alex4321
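To make that intuition concrete, a rough estimate (my own illustrative numbers, not a benchmark) of what gets materialized with and without FlashAttention for one attention layer:

```cuda
// Rough illustrative numbers: standard attention materializes an N x N score
// matrix per head, while FlashAttention recomputes those tiles on the fly and
// only keeps O(N) softmax statistics (running max and sum) per head.
#include <cstdio>

int main() {
    const long long seq_len = 8192, n_heads = 32, elt_bytes = 2;      // fp16
    long long scores = seq_len * seq_len * elt_bytes * n_heads;       // S = QK^T
    long long stats  = seq_len * 2 * (long long)sizeof(float) * n_heads;
    std::printf("materialized scores: %lld MiB per layer\n", scores >> 20);  // 4096
    std::printf("flash-attn stats:    %lld KiB per layer\n", stats  >> 10);  // 2048
    return 0;
}
```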

Turing cards have less shared memory (64KB instead of 99KB or 163KB on Ampere) so that might require adjusting the block sizes currently used.

Does this mean that completing support would require a lot of effort?

chuanzhubin-aiopx avatar Apr 15 '24 01:04 chuanzhubin-aiopx

That is editable in one file as far as I saw.

Ph0rk0z avatar Apr 15 '24 11:04 Ph0rk0z

OpenAI's Triton implementation of flash attention works on Turing GPUs (just tested this myself):

https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py

rationalism avatar May 01 '24 01:05 rationalism

Llama.cpp also has an implementation. The only problem with the vllm/OpenAI Triton implementation is that it isn't a drop-in replacement like the one here.

Ph0rk0z avatar May 01 '24 22:05 Ph0rk0z

+1

AvivSham avatar May 02 '24 15:05 AvivSham

@rationalism how do you install that version you mentioned? I need to use it with InternVL mentioned here. Thanks.

bit-scientist avatar Sep 06 '24 06:09 bit-scientist

I know that version. It's slow and it's missing a lot of stuff. I think vllm supports it. Anyway, I made a version that "works", based on, IIRC, 2.5.x. Unfortunately it cranks away for a while and then fails with "invalid parameter". I printed all the tensors and they appear normal right up until I hit the error. I even shrank the shared-memory (SRAM) usage on some of the kernels so they fit. I should post it up at some point; maybe someone better at this than me will be able to figure it out. It basically dies halfway into an output. I only care about the forward pass, not training.

https://github.com/Ph0rk0z/flash-attn-turning/commit/b3dc600f3916528105462e918434921f1dc65459

Ph0rk0z avatar Sep 06 '24 12:09 Ph0rk0z
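As a quick sanity check for anyone experimenting with that branch, the per-block shared-memory ceiling that any shrunken tile configuration has to respect can be queried directly (generic CUDA, not from the linked commit):

```cuda
// Generic query: report the per-block shared-memory limits of the installed
// GPU, which is the hard ceiling any reduced tile configuration must respect
// (64 KB opt-in on a 2080 Ti / sm_75).
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int dev = 0, optin_bytes = 0, default_bytes = 0;
    cudaGetDevice(&dev);
    cudaDeviceGetAttribute(&optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
    cudaDeviceGetAttribute(&default_bytes, cudaDevAttrMaxSharedMemoryPerBlock, dev);
    std::printf("smem per block: %d KB default, %d KB with opt-in\n",
                default_bytes / 1024, optin_bytes / 1024);
    return 0;
}
```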