Request to support FlashAttention in cuda attention.cc
FlashAttention can greatly reduce memory usage and speed up attention, even during inference. Is there any plan to support this implementation? https://github.com/facebookresearch/xformers/tree/main/xformers/csrc/attention
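The memory saving comes from never materializing the full N×N score matrix. Keys and values are processed block by block, and a running max and running sum keep the softmax exact ("online softmax"). Below is a pure-Python sketch of that idea for a single query vector, for illustration only. It is not the CUDA kernel from the xformers or Dao-AILab repositories, and the function names are hypothetical.

```python
import math


def naive_attention(q, keys, values):
    """Reference: materialize every score for query q, then softmax and weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]  # all N scores at once
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def tiled_attention(q, keys, values, block_size=2):
    """Online softmax: visit keys/values block by block, keeping only running statistics."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = -math.inf  # largest score seen so far
    running_sum = 0.0        # sum of exp(score - running_max) seen so far
    acc = [0.0] * dim        # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]  # only this block
        new_max = max(running_max, max(scores))
        # Rescale earlier contributions to the new max (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Both functions return the same result up to floating-point rounding. The tiled version only ever holds `block_size` scores at a time. The actual kernels apply the same bookkeeping per tile of queries and keys in fast on-chip memory.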
Hi, a new implementation has been released: https://tridao.me/publications/flash2/flash2.pdf It brings a 50% TFLOPS improvement on the forward pass compared to the original FlashAttention implementation, and a massive improvement compared to the vanilla attention mechanism: https://github.com/Dao-AILab/flash-attention
Hi @guillaumekln! How hard would it be to implement this?
Can you give us some pointers? Most of the popular libraries already support v1.0.
Hi, I think V2 will be much simpler to implement, as it comes with a higher-level library and broader GPU compatibility. It might also restrict the GPUs on which CTranslate2 can run, or we would need to add a special field to request flash attention.
At this time FlashAttention is mostly useful for training or when processing a long prompt. However, during inference most of the time is usually spent in the iterative decoding where the bottleneck is somewhere else.
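A rough way to see why is to count the attention scores a naive kernel would hold at once. During prompt processing this is N×N. During each decoding step it is a single row, so there is little score memory for FlashAttention to save there. The comment above only says the bottleneck lies elsewhere; this sketch does not identify it. The function names are hypothetical, and the counts are per head and per layer, ignoring the causal mask.

```python
def prefill_score_entries(prompt_len):
    # Processing the prompt: each of the N prompt tokens is scored against all N
    # prompt tokens, so a naive kernel materializes an N x N matrix.
    return prompt_len * prompt_len


def decode_step_score_entries(cached_len):
    # Iterative decoding: one new query token is scored against the cached keys
    # plus itself, i.e. a single 1 x (cached_len + 1) row.
    return cached_len + 1


if __name__ == "__main__":
    for n in (128, 1_000, 8_000):
        print(n, prefill_score_entries(n), decode_step_score_entries(n))
```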
It seems the author is still working to optimize this inference part where the input shape is different: https://github.com/Dao-AILab/flash-attention/issues/346#issuecomment-1642787240
Also I can't find an end-to-end benchmark using FlashAttention for inference. Do you know where we can find one? I only see benchmarks for end-to-end training or for the attention module itself.
But I believe it will be able to reduce the VRAM usage further. Can we get some support to run memory efficient attention?
Hello all. Just thought I'd post a question about Flash Attention 2 here:
https://github.com/Dao-AILab/flash-attention
Apparently it's making big waves and seems very powerful. Does anyone plan to look into whether it could be included?
I reviewed the prior comments and suggest that we change the topic to Flash Attention 2. I know that guillaumekln is no longer with faster-whisper, but hopefully one of the admins can weigh in on this potentially powerful feature for CTranslate2!
Hi, my thoughts on this: there are some major pros and some cons.
Pros:
- Reduced VRAM usage
- Flash-Decoding improves speed on long-sequence generation (I don't know whether something similar is already implemented)
- Faster inference
Cons:
- Introduces a dependency
- It's not compatible with all GPUs, so it will be tricky to work with
Could the devs share their thoughts on this? What work would be required?
Eventually this will be included, but including a pip package (we did include flash2 in OpenNMT-py) is not the same story as linking a C++ package that changes quite frequently. Of course, we don't want to drop the current path for scaled dot-product attention. It takes time, but new C++ developers are very welcome.
Bear in mind that native PyTorch is not dead: https://pytorch.org/blog/accelerating-generative-ai-2/?hss_channel=lcp-78618366 https://forum.opennmt.net/t/opennmt-py-v3-4-3-released-blazing-fast-beam-search-inference/5546
Another repo claims that flash attention makes transcription much faster: https://github.com/Vaibhavs10/insanely-fast-whisper From my reading, about 2x?
CTranslate2 will soon support FlashAttention 2, following PR #1651. I will do the release asap. I ran some tests and saw a performance improvement with long prompts. As mentioned in the original repo, it only runs on GPU architectures >= sm80. It would be great if you could test it.
thanks, looking forward to test it with faster-whisper!
This is great! Any chance you could provide some tips as to how to test this on faster-whisper?
Make sure you have an Ampere GPU or newer. You can just set flash_attention=True when loading the model to use Flash Attention instead of standard MHA.
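Here is a minimal sketch that guards the flag with the sm80 requirement and falls back to standard MHA on older GPUs. It assumes PyTorch is installed only to query the compute capability. `supports_flash_attention` and `load_whisper` are hypothetical helpers. The `flash_attention=True` keyword on `WhisperModel` is the one used in the benchmark script later in this thread.

```python
def supports_flash_attention(capability):
    """FlashAttention 2 needs compute capability >= 8.0 (sm80, Ampere or newer)."""
    return tuple(capability) >= (8, 0)


def load_whisper(model_path, want_flash=True):
    # Imported lazily so the capability check above can be used without a GPU stack.
    import torch
    from faster_whisper import WhisperModel

    # torch.cuda.get_device_capability() returns (major, minor), e.g. (8, 6) for an A10G.
    flash = want_flash and supports_flash_attention(torch.cuda.get_device_capability())
    # On pre-Ampere GPUs, load with the standard MHA path instead of requesting FA2.
    model = WhisperModel(model_path, device="cuda", flash_attention=flash)
    return model, flash
```

Usage: `model, used_flash = load_whisper("/home/user/whisper-large-v2-ct2")`. Returning `used_flash` makes it easy to log which attention path a benchmark actually ran.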
Hi @minhthuc2502, Do you have a benchmark comparing Faster Whisper with and without Flash Attention?
Hello, I did not benchmark Faster Whisper, but there are some benchmarks of Flash Attention with several LLM models here.
I haven't benchmarked Whisper with flash attention, but my hypothesis is that it will not make much of a difference for a beam size of 1, but that it "might" if the beam size is increased. However, the benefit will likely not be nearly as great as with stereotypical chat models. I base this conclusion on the following:
- My testing of flash attention indicates noticeable VRAM savings and a speedup for chat models run with CTranslate2, except for Llama2 models (likely due to architectural differences), and this is most noticeable when you increase the beam size. Thus, FA2 seems to provide improvements "across the board" as the beam size grows, and there's no indication that this wouldn't also be the case for Whisper (as opposed to chat) models.
- However, my testing was geared towards a "RAG" use case. This scenario involves sending a single question to an LLM, accompanied by "contexts" from a corpus, the goal being for the LLM to respond solely based on the provided contexts. The question and the provided contexts together constitute the "prompt" for the LLM to process. The "prompt" in my testing was approximately 1000 tokens.
- In the linked conversation that @minhthuc2502 provided, he states that the benefits of FA2 should be greater the longer the "prompt" is. Since my "prompt" was only ~1000 tokens, if what @minhthuc2502 says is true, I didn't fully test the benefits of flash attention in CTranslate2... again, my testing was geared towards RAG.
- In a non-RAG scenario, such as a multi-turn conversation with a chat LLM, the entire conversation is sent to the LLM each time: every new user message or LLM response is appended to the chat history, and the whole history is resent. This is commonly referred to as "memory" and is different from a single question like my RAG scenario. In a conversation with memory, the chat history can easily grow beyond 1000 tokens and will often exceed the LLM's context window. Again, I didn't test a "prompt" above 1000 tokens.
- With this background: the Whisper models themselves can only process up to 30 seconds of audio per chunk. This is an inherent limitation of how OpenAI trained the Whisper models. The VAD (e.g., see the faster-whisper repo) removes silent portions so that only speech is packed into the 30-second window, but the 30-second window remains...
- As such, you won't see the benefit of flash attention with Whisper because, unless you can cram far more than 1000 tokens into that 30-second window, you won't get the benefit that @minhthuc2502 attributes to a longer "prompt". However, as I mentioned, this doesn't contradict my findings regarding the benefits of flash attention when increasing the beam size, even with smaller 1000-token chunks.
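The "memory" point above is a running total, so the size of the prompt grows with every message. The sketch below uses hypothetical per-message token counts and helper names to show how quickly a resent history passes the ~1000-token prompts used in the RAG tests.

```python
from itertools import accumulate


def resent_prompt_lengths(message_tokens):
    """Running size of the chat history, i.e. what gets resent as each message is appended."""
    return list(accumulate(message_tokens))


def first_message_over(message_tokens, limit=1000):
    """1-based index of the first message after which the history exceeds `limit` tokens, else None."""
    for i, total in enumerate(accumulate(message_tokens), start=1):
        if total > limit:
            return i
    return None


if __name__ == "__main__":
    turns = [150, 300, 120, 450]  # hypothetical user/assistant message sizes, in tokens
    print(resent_prompt_lengths(turns), first_message_over(turns))
```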
Keep in mind that I just haven't had time to test this. In my testing I try to honestly represent people's hard work, but I'm not a programmer by trade and this is a hobby of mine, so... take it with a grain of salt. Hope this helps!
Hi @BBC-Esq, thank you for your insights! Regarding Whisper and the number of tokens: every 30-second window is converted to mel-spectrogram features, i.e., 3,000 frames, each with 80 features. Therefore I expected to see some boost when using FA. Additionally, the default beam size for faster-whisper is 5.
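The frame count per window follows from Whisper's front-end constants: 16 kHz audio, a 160-sample hop (100 frames per second), and 80 mel bins for large-v2. The encoder's second convolution has stride 2, so attention inside the encoder runs over half as many positions. The helper names below are hypothetical.

```python
SAMPLE_RATE = 16_000  # Whisper resamples audio to 16 kHz
HOP_LENGTH = 160      # 10 ms hop between STFT frames -> 100 frames per second
N_MELS = 80           # mel bins for large-v2 (large-v3 uses 128)
CHUNK_SECONDS = 30


def mel_frames(seconds=CHUNK_SECONDS):
    """Number of mel frames for a (padded) chunk of `seconds` audio."""
    return seconds * SAMPLE_RATE // HOP_LENGTH


def encoder_positions(frames, conv_stride=2):
    # The encoder's second convolution has stride 2, halving the time axis.
    return frames // conv_stride


if __name__ == "__main__":
    frames = mel_frames()
    print("encoder input shape:", (1, N_MELS, frames))
    print("attention sequence length:", encoder_positions(frames))
```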
@minhthuc2502 The reason I was asking about faster whisper FA benchmark is that I do not see any improvement in speed when loading the whisper model with FA.
Here is the code I used to benchmark:
```python
import time

import torch
from faster_whisper import WhisperModel


def benchmark(model):
    times = []

    # Warmup
    for i in range(10):
        segments, _ = model.transcribe(
            "sample_1.wav",
            language="fr",
        )
        segments = list(segments)

    # Benchmark
    for i in range(100):
        segments, _ = model.transcribe(
            "sample_1.wav",
            language="fr"
        )
        # transcribe() returns a lazy generator; decoding runs when it is consumed below.
        # Work done eagerly inside transcribe() (e.g. loading the audio) is not timed.
        past = time.time()
        segments = list(segments)
        torch.cuda.synchronize()
        times.append(time.time() - past)

    times = times[1:]  # drop the first timed run
    print(f"Mean inference time: {sum(times) / len(times)}")
    print(f"\nTIMES: {times}")


if __name__ == '__main__':
    # model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=True)
    model = WhisperModel("/home/user/whisper-large-v2-ct2", flash_attention=False)
    benchmark(model)
```
The results for the above code snippet (after running it twice independently) are:
- With FA: Mean inference time: 0.8763072201699922
- W/O FA: Mean inference time: 0.8619994466955011
About the setup:
- the audio file is ~50 sec long.
- GPU: A10G
- ctranslate2 version: 4.2.1
Is this result expected? if not what can be done to make it faster?
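To put the two means above in perspective, their relative difference can be computed directly. With FA, the mean is about 1.7% higher than without, so this setup shows no speedup. `relative_change` is a hypothetical helper.

```python
def relative_change(new, baseline):
    """Fractional change of `new` relative to `baseline`; positive means `new` is slower."""
    return (new - baseline) / baseline


with_fa = 0.8763072201699922     # mean seconds reported above, flash_attention=True
without_fa = 0.8619994466955011  # mean seconds reported above, flash_attention=False
print(f"FA vs. no FA: {relative_change(with_fa, without_fa):+.2%}")
```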
@AvivSham If you're asking for my opinion on how to speed things up generally, faster-whisper has a pull request for batch processing that's not yet approved. If you don't want to wait for it you can use the WhisperS2T library, but once the faster-whisper pull request is approved the speed will be comparable to that of WhisperS2T.
But if you're asking how to make it faster with flash attention, based on the assumption that you might not be using flash attention correctly with faster-whisper...afraid I can't really help. @minhthuc2502 might be able to help, but what I've learned is that those kinds of questions are better posted on the faster-whisper repo. Those peeps are more responsible for actually implementing new features provided by ctranslate2. I.e., unless there's a problem with ctranslate2's implementation of flash attention for whisper models IN GENERAL, the issue would be better addressed at faster-whisper.
With that being said, I can confirm that flash attention works for "chat" models so I'd be surprised if there's some kind of core issue with the ctranslate2 library that prevents it from working just with Whisper models...
BTW, when I said "I can't really help", it's not that I don't want to... I've just reached the limit of my personal knowledge. Programming is only a hobby for me, after all. ;-)
@AvivSham You might also test your script with beam sizes 1 to 5 and see whether there's a difference. If there's a noticeable difference between running with and without flash attention, you could rule out the possibility that the flash attention parameter somehow isn't being used at all. At the end of this discussion, they do confirm that flash attention can be used:
https://github.com/SYSTRAN/faster-whisper/issues/598
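Here is a sketch of that beam-size sweep. It assumes two already-loaded `WhisperModel` instances, one with `flash_attention=True` and one with `flash_attention=False`. `beam_size` and `language` are existing `transcribe()` arguments in faster-whisper. `time_call`, `beam_sweep`, `speedups` and the "FA2"/"MHA" labels are hypothetical names chosen for illustration.

```python
import time
from statistics import mean


def time_call(fn, runs=5):
    """Mean wall-clock seconds of fn() over `runs` calls."""
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return mean(samples)


def beam_sweep(models, audio_path, beam_sizes=range(1, 6), language="fr", runs=5):
    """Time each model at each beam size.

    models: dict mapping a label (e.g. "FA2", "MHA") to a loaded WhisperModel.
    Returns {(label, beam_size): mean_seconds}.
    """
    results = {}
    for beam_size in beam_sizes:
        for label, model in models.items():
            def run():
                segments, _ = model.transcribe(audio_path, language=language, beam_size=beam_size)
                list(segments)  # consume the lazy generator so decoding actually happens
            run()  # one warm-up call per configuration
            results[(label, beam_size)] = time_call(run, runs)
    return results


def speedups(results, baseline="MHA", candidate="FA2"):
    """baseline_time / candidate_time per beam size; values above 1.0 mean the candidate is faster."""
    beam_sizes = sorted({b for _, b in results})
    return {b: results[(baseline, b)] / results[(candidate, b)] for b in beam_sizes}
```

For example: `speedups(beam_sweep({"FA2": fa_model, "MHA": mha_model}, "sample_1.wav"))`. Holding both models in memory at once doubles VRAM use. On a smaller GPU, sweep one model per process and merge the result dicts.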
Thank you for your attempt to help! 😄 I will post this question directly in the faster-whisper repo while waiting for @minhthuc2502 's response.
For more information, I ran some benchmarks for faster-whisper with FlashAttention here.
In my recent FA2 benchmarks, I noticed that the difference between FA2 and standard MHA becomes more obvious at longer sequence lengths. In the case of faster-whisper, however, each 30-second audio chunk is converted to an encoder input with the shape (1, 80, 3000); see here. That sequence length is too short to get much benefit from FA2.
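One way to see why length matters is to look at the score matrix a naive attention kernel materializes, which FlashAttention avoids. It grows with the square of the sequence length, so doubling the length quadruples it. The sketch assumes FP16 scores and 20 heads (Whisper large-v2) and covers one encoder self-attention layer. At Whisper's 1500 encoder positions, that is 90 MB per layer. The longer lengths in the loop are arbitrary comparisons, and `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(seq_len, n_heads=20, bytes_per_elem=2, batch=1):
    """Bytes needed to hold one layer's full attention-score matrix (batch x heads x N x N)."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    # 1500 = Whisper encoder positions for one 30 s chunk; the other lengths are for comparison.
    for n in (1_500, 4_000, 16_000):
        print(f"N={n:>6}: {score_matrix_bytes(n) / 1e6:,.0f} MB per layer")
```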