[Kernel] Unify the kernel used in flash attention backend
Currently, we are using different kernels for different phases. Concretely, we use `flash_attn_with_kvcache` for the decoding phase and `flash_attn_varlen_func` for the prefill phase and prefix caching. For chunked prefill, we launch both kernels to handle prefill tokens and decoding tokens separately. The current approach has some drawbacks:
- It complicates the attention backend logic because we need both `prefill_metadata` and `decode_metadata`.
- It passes more fields from the `model_runner` to the backend than needed, because `flash_attn_with_kvcache` and `flash_attn_varlen_func` have different input requirements.
- It uses two kernels for chunked prefill, which is not optimal for performance.
- There is potential performance degradation because we need to build `prefill_metadata` and `decode_metadata` on the fly, although this might be minor since we cache the two metadata objects.
Moreover, `flash_attn_with_kvcache` and `flash_attn_varlen_func` have similar performance, as they share the same underlying implementation.
Ideally, we should use a single kernel to handle all cases, including the prefill phase, the decoding phase, and prefix caching. For chunked prefill, we should launch a single kernel to handle both the prefill tokens and the decoding tokens.
This PR tries to simplify the logic in the attention backend and use a single kernel. This is also needed for the MQA scorer (#5691) for speculative decoding.
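For illustration only (this is not code from the PR), here is a minimal sketch of how a mixed batch of prefill and decode requests can be described by one set of varlen metadata, so that a single `flash_attn_varlen_func`-style call covers both kinds of tokens. The sequence lengths below are made up.

```python
import torch

# Hypothetical mixed batch: two prefill requests (5 and 3 new tokens) and two
# decode requests (1 new token each), with different total context lengths.
query_lens = [5, 3, 1, 1]    # new (query) tokens per sequence in this step
kv_lens = [5, 3, 17, 42]     # total context (cached + new) per sequence

# A single varlen-style call only needs cumulative-length tensors plus the
# maxima; prefill and decode sequences are described the same way.
cu_seqlens_q = torch.tensor([0] + query_lens).cumsum(0).to(torch.int32)
cu_seqlens_k = torch.tensor([0] + kv_lens).cumsum(0).to(torch.int32)
max_seqlen_q = max(query_lens)   # 5
max_seqlen_k = max(kv_lens)      # 42

print(cu_seqlens_q.tolist())  # [0, 5, 8, 9, 10]
print(cu_seqlens_k.tolist())  # [0, 5, 8, 25, 67]
```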
I think the direction makes sense! It is also a more cuda-graph-friendly approach.
- Is this PR ready?
- The original reason I didn't try this before was that I heard the perf wasn't that different (or was worse, due to some optimizations for the decode case). Can you share the benchmark results?
Yeah, the PR should be ready for review.
Some kernel benchmark numbers on a single A100; all numbers are in ms.
| Number of query tokens | Number of heads | Head dim | flash_attn_varlen_func (ms) | flash_attn_with_kvcache (ms) |
|---|---|---|---|---|
| 100 | 12 | 64 | 0.0917 | 0.0365 |
| 500 | 12 | 64 | 0.379 | 0.383 |
| 1000 | 12 | 64 | 1.292 | 1.290 |
| 100 | 32 | 128 | 0.0989 | 0.100 |
| 500 | 32 | 128 | 1.550 | 1.549 |
| 1000 | 32 | 128 | 5.819 | 5.837 |
| 100 | 64 | 128 | 0.161 | 0.160 |
| 500 | 64 | 128 | 2.965 | 3.004 |
| 1000 | 64 | 128 | 11.308 | 11.388 |
There is only one case where we see significant performance degradation:
| Number of query tokens | Number of heads | Head dim | flash_attn_varlen_func (ms) | flash_attn_with_kvcache (ms) |
|---|---|---|---|---|
| 100 | 12 | 64 | 0.0917 | 0.0365 |
In all other cases, the performance is quite similar.
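For reference, this is roughly how per-kernel latencies like the ones above can be measured. This is a sketch with my own helper name, not necessarily the benchmark actually used; it assumes a CUDA device and pre-built inputs for each kernel.

```python
import torch

def time_cuda_kernel(fn, warmup: int = 10, iters: int = 100) -> float:
    """Average latency of fn() in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# e.g. compare the two kernels on identical shapes (inputs elided here):
#   time_cuda_kernel(lambda: flash_attn_varlen_func(q, k, v, ...))
#   time_cuda_kernel(lambda: flash_attn_with_kvcache(q, k_cache, v_cache, ...))
```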
The review ETA is tonight!
Besides, I'd like to know the e2e performance improvement (or confirm that it at least matches the current performance). Is it possible to run some e2e benchmarks with/without the PR and share the results?
Looks like the model output is garbled and totally different after unifying the kernel... I changed flash_attn.py back to the original implementation, with flash_attn_with_kvcache for decode and flash_attn_varlen_func for prefill, and the result is normal. Have you checked the correctness? @LiuXiaoxuanPKU
If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567
My guess is that flash_attn_varlen_func ONLY works without cuda graph... but I don't know why.
> Looks like the model output is garbled and totally different after unifying the kernel... I changed flash_attn.py back to the original implementation, with flash_attn_with_kvcache for decode and flash_attn_varlen_func for prefill, and the result is normal. Have you checked the correctness? @LiuXiaoxuanPKU
Thanks for reporting, will take a look.
> If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567 My guess is that `flash_attn_varlen_func` ONLY works without cuda graph... but I don't know why.
@jjjjohnson Could you provide the model/prompt you used for testing? The results seem correct for basic_correctness. Thanks!
@rkooo567 Some e2e performance numbers for llama-7b on a single H100 with cuda graph. All numbers are 50th-percentile request latency in seconds, measured with the benchmark script.
| input_len | output_len | batch_size | this PR | main branch |
|---|---|---|---|---|
| 32 | 128 | 1 | 0.883 | 0.917 |
| 32 | 128 | 2 | 0.877 | 0.912 |
| 32 | 128 | 4 | 0.906 | 0.933 |
| 32 | 128 | 8 | 0.946 | 0.956 |
| 32 | 128 | 16 | 1.065 | 1.084 |
| 32 | 128 | 32 | 1.236 | 1.259 |
| 512 | 32 | 1 | 0.229 | 0.251 |
| 512 | 32 | 2 | 0.237 | 0.259 |
| 512 | 32 | 4 | 0.267 | 0.288 |
| 512 | 32 | 8 | 0.326 | 0.347 |
| 512 | 32 | 16 | 0.456 | 0.498 |
| 512 | 32 | 32 | 0.699 | 0.779 |
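For completeness, a rough sketch of how such p50 request latencies can be collected with vLLM's offline API. This is not necessarily the exact script used above; the model name, dummy token IDs, and iteration count are placeholders.

```python
import time
import numpy as np
from vllm import LLM, SamplingParams

# Placeholder settings mirroring one row of the table above.
input_len, output_len, batch_size = 512, 32, 8

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # add enforce_eager=True to disable cuda graph
params = SamplingParams(max_tokens=output_len, ignore_eos=True)

# Dummy prompts of `input_len` tokens each (placeholder token IDs).
prompt_token_ids = [[0] * input_len for _ in range(batch_size)]

latencies = []
for _ in range(10):
    start = time.perf_counter()
    llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=params)
    latencies.append(time.perf_counter() - start)

print(f"p50 latency: {np.percentile(latencies, 50):.3f} s")
```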
> If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567 My guess is that `flash_attn_varlen_func` ONLY works without cuda graph... but I don't know why.
I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
Hmm, that's pretty odd. There's nothing LoRA-related in this kernel, IIUC.
BTW, I saw a CI failure in LM Eval Small Models as follows:
```
[2024-07-23T14:54:54Z] >       assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
[2024-07-23T14:54:54Z] E       assert False
[2024-07-23T14:54:54Z] E        +  where False = <function isclose at 0x7f72938ba070>(0.593, 0.0, rtol=0.02)
[2024-07-23T14:54:54Z] E        +  where <function isclose at 0x7f72938ba070> = numpy.isclose
```
Looks like the measured_value is 0, so the output may be garbage in this case.
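For context, the assertion compares the measured lm-eval score against the expected one with only a relative tolerance, so a measured value of 0.0 (garbage output) can never pass; the numbers below are copied from the log:

```python
import numpy

RTOL = 0.02
ground_truth, measured_value = 0.593, 0.0

# numpy.isclose(a, b, rtol=...) checks |a - b| <= atol + rtol * |b| with a tiny
# default atol, so with b == 0.0 essentially nothing non-zero can pass.
print(numpy.isclose(ground_truth, measured_value, rtol=RTOL))  # False
```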
> If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567 My guess is that `flash_attn_varlen_func` ONLY works without cuda graph... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
I tried Qwen/Qwen-14B-Chat without LoRA; it can be any prompt, and the result is totally different with vs. without enforce-eager.
> If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567 My guess is that `flash_attn_varlen_func` ONLY works without cuda graph... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
Looks like short prompts are OK; if you change to example_long_prompts, the tests fail...
> example_long_prompts
>
> Qwen/Qwen-14B-Chat
I tried example_long_prompts with Qwen and it did fail. But after looking into it, it fails in both eager and non-eager mode, and it also fails for other backends such as XFORMERS. Therefore, it seems like a numerical issue in that case. Did you observe similar behavior?
> If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567 My guess is that `flash_attn_varlen_func` ONLY works without cuda graph... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
>
> I tried Qwen/Qwen-14B-Chat without LoRA; it can be any prompt, and the result is totally different with vs. without enforce-eager.
Could you provide the exact prompt and the hardware you use? After some manual checking on H100 with Qwen/Qwen-14B-Chat, setting enforce-eager or not gives the same output. It might also be possible that the bug with cuda graph preparation does not reproduce consistently. Thanks!
> If I add `--enforce-eager`, which disables cuda graph, the model output text is normal. But if I enable cuda graph, the output is totally different. @comaniac @rkooo567 My guess is that `flash_attn_varlen_func` ONLY works without cuda graph... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
>
> I tried Qwen/Qwen-14B-Chat without LoRA; it can be any prompt, and the result is totally different with vs. without enforce-eager.
>
> Could you provide the exact prompt and the hardware you use? After some manual checking on H100 with Qwen/Qwen-14B-Chat, setting enforce-eager or not gives the same output. It might also be possible that the bug with cuda graph preparation does not reproduce consistently. Thanks!
I use A800 TP1.
Prompt:
The rapid advancement in artificial intelligence (AI) has yielded a variety of groundbreaking technologies, among which Large Language Models (LLMs) have garnered widespread attention and utility. LLMs, such as OpenAI’s GPT-4, Google's BERT, and others, have profoundly transformed the landscape of natural language processing (NLP) over the past few years. But what exactly are LLMs, and why are they so significant?At their core, LLMs are a subset of machine learning models designed to understand.LLMs are versatile and can be fine-tuned for a variety of applications. From drafting emails and writing code to translating languages and composing poetry, the potential use cases are vast.
If I change
`@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS"])`
to
`@pytest.mark.parametrize("backend", ["XFORMERS", "FLASH_ATTN"])`
the tests pass... Pretty odd...
@jjjjohnson, when you say the test fails, was the output gibberish or still something reasonable? Changing the kernel may change the numerics slightly.
> example_long_prompts

I think long prompts are more likely to accumulate numerical error, so this checks out?
Updates for this PR:
- We will take a less aggressive approach: keep the original kernels for prefill and decoding, and use `flash_attn_varlen_func` for mixed batches (batches that contain both prefill tokens and decoding tokens). The goal is to enable cuda graph for chunked prefill and speculative decoding, as sketched below.
- We need to debug the cuda graph compatibility of the `flash_attn_varlen_func` kernel, as it currently fails the unit tests.
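To make the updated plan concrete, here is an illustrative sketch of the dispatch it describes. The metadata class and function names below are made up for illustration, not the actual vLLM attention-backend code.

```python
from dataclasses import dataclass

@dataclass
class BatchMetadata:
    """Illustrative stand-in for the backend's attention metadata."""
    num_prefill_tokens: int
    num_decode_tokens: int

def pick_kernel(meta: BatchMetadata) -> str:
    """Which kernel the updated plan would launch for a given batch."""
    if meta.num_prefill_tokens > 0 and meta.num_decode_tokens > 0:
        # Mixed batch (chunked prefill / speculative decoding): a single
        # flash_attn_varlen_func call covers all tokens, keeping the step
        # cuda-graph friendly.
        return "flash_attn_varlen_func (single call for the whole batch)"
    if meta.num_prefill_tokens > 0:
        # Pure prefill (including prefix caching): keep the original path.
        return "flash_attn_varlen_func (original prefill path)"
    # Pure decode batch: keep the original decode kernel.
    return "flash_attn_with_kvcache"

print(pick_kernel(BatchMetadata(num_prefill_tokens=96, num_decode_tokens=4)))
# -> flash_attn_varlen_func (single call for the whole batch)
```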
Hi @LiuXiaoxuanPKU
Based on your current test case defined in tests/kernels/test_flash_attn.py, here is a modified version: test_varlen_cg.py.
It should now pass the given case for mixed prefill and decode with vllm_flash_attn v2.6.2: `python3 -m pytest test_varlen_cg.py`.
The major modifications when using flash_attn_varlen_func with cuda graph are the following (see the sketch after this list):
- We need to keep `max_query_len` and `max_kv_len` static, to ensure these CPU-side values have no effect on the results.
  - Given that `params.num_splits` is 0 now and we still provide a page table, the call dispatches to `flash_fwd_splitkv_kernel`. So we need to keep (i) all the kernel grid dims static, which use `max_query_len`, and (ii) the kernel template static, which uses `max_kv_len`.
- Use static GPU memory for `g_cu_query_lens`, `g_cu_kv_lens`, and `g_block_tables`, and pad the remaining batch indices with non-decreasing seqlens.
  - Because their shapes have a batch dimension, we need to keep them static and pad the rest. This requires capturing with the largest number of query tokens needed. For example, we capture with `[(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]` and can run with `[(5, 18), (1, 473), (1, 6), (0, 0), (0, 0), (0, 0), (0, 0)]`. The `(0, 0)` entries are needed to keep the prepared, padded `cu_*_lens` non-decreasing, so that the GPU blocks responsible for the padded entries won't pollute the result.
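As a concrete illustration of the padding described in the second point, here is a small sketch of the idea using the `(query_len, kv_len)` pairs from the example above; the helper is my own, not the test file itself.

```python
import torch

def pad_cu_seqlens(seqlens, max_batch_size, device="cpu"):
    """Build a fixed-size cu_seqlens tensor suitable for cuda graph replay.

    Unused trailing slots repeat the last cumulative value, which is the same
    as appending (0, 0)-length sequences: the prefix sums stay non-decreasing
    and the padded entries contribute no tokens.
    """
    cu = [0]
    for n in seqlens:
        cu.append(cu[-1] + n)
    cu += [cu[-1]] * (max_batch_size - len(seqlens))  # pad with zero-length seqs
    return torch.tensor(cu, dtype=torch.int32, device=device)

# Captured with seven (1, 1) sequences; replayed with three real sequences of
# (query_len, kv_len) = (5, 18), (1, 473), (1, 6), padded up to batch size 7.
g_cu_query_lens = pad_cu_seqlens([5, 1, 1], max_batch_size=7)
g_cu_kv_lens = pad_cu_seqlens([18, 473, 6], max_batch_size=7)
print(g_cu_query_lens.tolist())  # [0, 5, 6, 7, 7, 7, 7, 7]
print(g_cu_kv_lens.tolist())     # [0, 18, 491, 497, 497, 497, 497, 497]
```

In the actual capture/replay flow, these padded values would be copied into the pre-allocated GPU buffers rather than newly created, since cuda graph replay requires fixed memory addresses.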
Feel free to try it out. Hope it helps :)
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LiuXiaoxuanPKU.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
You're trying to merge from a branch that is "2739 commits behind"; there is no chance your PR will pass all the CI tests. I advise rebasing your feature branch so it closely tracks the current state of main in vllm-project/vllm.
Thanks!