
[Kernel] Unify the kernel used in flash attention backend

Open LiuXiaoxuanPKU opened this issue 1 year ago • 2 comments

Currently, we use different kernels for different phases: flash_attn_with_kvcache for the decoding phase and flash_attn_varlen_func for the prefill phase and prefix caching. For chunked prefill, we launch both kernels to handle the prefill tokens and the decoding tokens separately. The current approach has several drawbacks:

  1. It complicates the attention backend logic because we need both prefill_metadata and decode_metadata.
  2. It passes more fields than needed from the model_runner to the backend, because flash_attn_with_kvcache and flash_attn_varlen_func require different inputs.
  3. It uses two kernel launches for chunked prefill, which is not optimal for performance.
  4. It can degrade performance because prefill_metadata and decode_metadata must be built on the fly, although this is likely minor since the two metadata objects are cached.

Moreover, flash_attn_with_kvcache and flash_attn_varlen_func have similar performance, as they share the same underlying implementation. Ideally, we should use a single kernel to handle all cases, including the prefill phase, the decoding phase, and prefix caching. For chunked prefill, we should launch a single kernel that handles both the prefill tokens and the decoding tokens.
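To make the contrast concrete, here is a minimal sketch of the current two-kernel path versus a unified one. This is not the actual vLLM backend code: the `meta` object and its fields are hypothetical placeholders, and the import path and keyword arguments assume the public flash-attn 2.x style API (`flash_attn_with_kvcache` / `flash_attn_varlen_func` with paged-KV `block_table` support), which may differ slightly from the vllm_flash_attn fork.

```python
# Sketch only: `meta` is a hypothetical metadata container, not vLLM's
# FlashAttentionMetadata. Import path assumes the vllm-flash-attn package.
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


def attention_two_kernels(q_prefill, q_decode, key_cache, value_cache, meta):
    """Current approach: one kernel per phase, so chunked prefill launches both."""
    prefill_out = flash_attn_varlen_func(
        q=q_prefill,                       # (num_prefill_tokens, heads, head_dim)
        k=key_cache, v=value_cache,        # paged KV cache
        cu_seqlens_q=meta.prefill.query_start_loc,
        cu_seqlens_k=meta.prefill.seq_start_loc,
        max_seqlen_q=meta.prefill.max_query_len,
        max_seqlen_k=meta.prefill.max_seq_len,
        causal=True,
        block_table=meta.prefill.block_tables,
    )
    decode_out = flash_attn_with_kvcache(
        q=q_decode.unsqueeze(1),           # (num_decode_seqs, 1, heads, head_dim)
        k_cache=key_cache, v_cache=value_cache,
        cache_seqlens=meta.decode.seq_lens,
        block_table=meta.decode.block_tables,
        causal=True,
    )
    return prefill_out, decode_out.squeeze(1)


def attention_unified(q_all, key_cache, value_cache, meta):
    """Proposed approach: decode tokens are just sequences with query length 1,
    so a single varlen call covers prefill, decode, and prefix caching."""
    return flash_attn_varlen_func(
        q=q_all,                           # all tokens of the batch, concatenated
        k=key_cache, v=value_cache,
        cu_seqlens_q=meta.query_start_loc, # covers both prefill and decode rows
        cu_seqlens_k=meta.seq_start_loc,
        max_seqlen_q=meta.max_query_len,
        max_seqlen_k=meta.max_seq_len,
        causal=True,
        block_table=meta.block_tables,
    )
```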

This PR simplifies the logic in the attention backend by using a single kernel. This is also needed for the MQA scorer (#5691) in speculative decoding.

LiuXiaoxuanPKU avatar Jul 02 '24 00:07 LiuXiaoxuanPKU

I think the direction makes sense! It is also a more CUDA-graph-friendly approach.

QQ

  1. Is this PR ready?
  2. The original reason I didn't try this before was that I heard the performance wasn't that different (or was worse, due to some decode-specific optimizations). Can you share the benchmark results?

rkooo567 avatar Jul 02 '24 11:07 rkooo567

Yeah, the PR should be ready for review.

Some kernel benchmark numbers on a single A100; all numbers are in ms.

| Number of query tokens | Number of heads | Head dim | flash_attn_varlen_func (ms) | flash_attn_with_kvcache (ms) |
|---|---|---|---|---|
| 100 | 12 | 64 | 0.0917 | 0.0365 |
| 500 | 12 | 64 | 0.379 | 0.383 |
| 1000 | 12 | 64 | 1.292 | 1.290 |
| 100 | 32 | 128 | 0.0989 | 0.100 |
| 500 | 32 | 128 | 1.550 | 1.549 |
| 1000 | 32 | 128 | 5.819 | 5.837 |
| 100 | 64 | 128 | 0.161 | 0.160 |
| 500 | 64 | 128 | 2.965 | 3.004 |
| 1000 | 64 | 128 | 11.308 | 11.388 |

There is only one case where we see significant performance degradation:

| Number of query tokens | Number of heads | Head dim | flash_attn_varlen_func (ms) | flash_attn_with_kvcache (ms) |
|---|---|---|---|---|
| 100 | 12 | 64 | 0.0917 | 0.0365 |

In all other cases, the performance is quite similar.
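For anyone who wants to reproduce this kind of kernel micro-benchmark, the timing loop below is a minimal sketch using CUDA events; the kernel call is passed in as a closure, and the exact shapes and arguments used for the table above are not reproduced here.

```python
import torch


def time_kernel_ms(fn, warmup=10, iters=100):
    """Return the mean latency of `fn()` in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


# Example usage (all arguments are placeholders):
#   time_kernel_ms(lambda: flash_attn_varlen_func(q, k, v, cu_q, cu_k, max_q, max_k, causal=True))
```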

LiuXiaoxuanPKU avatar Jul 09 '24 04:07 LiuXiaoxuanPKU

The review ETA is tonight!

Also, I'd like to know the e2e performance improvement (or confirmation that it matches the current performance). Is it possible to run some e2e benchmarks with and without the PR and share the results?

rkooo567 avatar Jul 15 '24 23:07 rkooo567

Looks like the model output is garbled and totally different after unifying the kernel... I changed flash_attn.py back to the original implementation (flash_attn_with_kvcache for decode and flash_attn_varlen_func for prefill) and the result is normal. Have you checked the correctness? @LiuXiaoxuanPKU

jjjjohnson avatar Jul 22 '24 02:07 jjjjohnson

If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.

jjjjohnson avatar Jul 22 '24 13:07 jjjjohnson

> Looks like the model output is garbled and totally different after unifying the kernel... I changed flash_attn.py back to the original implementation (flash_attn_with_kvcache for decode and flash_attn_varlen_func for prefill) and the result is normal. Have you checked the correctness? @LiuXiaoxuanPKU

Thanks for reporting, will take a look.

LiuXiaoxuanPKU avatar Jul 22 '24 13:07 LiuXiaoxuanPKU

> If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.

@jjjjohnson Could you provide the model/prompt you used for testing? The results seem correct for basic_correctness. Thanks!

LiuXiaoxuanPKU avatar Jul 23 '24 12:07 LiuXiaoxuanPKU

@rkooo567 Some e2e performance numbers for llama-7b on a single H100 with CUDA graphs enabled. All numbers are 50th-percentile (p50) request latency in seconds, measured with the script.

| input_len | output_len | batch_size | this PR (s) | main branch (s) |
|---|---|---|---|---|
| 32 | 128 | 1 | 0.883 | 0.917 |
| 32 | 128 | 2 | 0.877 | 0.912 |
| 32 | 128 | 4 | 0.906 | 0.933 |
| 32 | 128 | 8 | 0.946 | 0.956 |
| 32 | 128 | 16 | 1.065 | 1.084 |
| 32 | 128 | 32 | 1.236 | 1.259 |
| 512 | 32 | 1 | 0.229 | 0.251 |
| 512 | 32 | 2 | 0.237 | 0.259 |
| 512 | 32 | 4 | 0.267 | 0.288 |
| 512 | 32 | 8 | 0.326 | 0.347 |
| 512 | 32 | 16 | 0.456 | 0.498 |
| 512 | 32 | 32 | 0.699 | 0.779 |
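(If anyone wants to reproduce a similar measurement without the benchmark script, the snippet below is a rough sketch of a p50 request-latency loop using the public vllm.LLM API. The model name, prompt length, and iteration count are placeholders, and this is not the script used for the table above.)

```python
import time

import numpy as np
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")            # placeholder model
params = SamplingParams(temperature=0.0, max_tokens=128, ignore_eos=True)
prompt = "hello " * 32                                  # roughly input_len=32, batch_size=1

latencies = []
for _ in range(30):                                     # repeat to get a stable percentile
    t0 = time.perf_counter()
    llm.generate([prompt], params)
    latencies.append(time.perf_counter() - t0)

print(f"p50 request latency: {np.percentile(latencies, 50):.3f} s")
```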

LiuXiaoxuanPKU avatar Jul 23 '24 13:07 LiuXiaoxuanPKU

> If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.

I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.

LiuXiaoxuanPKU avatar Jul 23 '24 20:07 LiuXiaoxuanPKU

Hmm, that's pretty odd. There's nothing LoRA-related in this kernel, IIUC.

rkooo567 avatar Jul 23 '24 20:07 rkooo567

BTW, I saw a CI failure in LM Eval Small Models as follows:

```
[2024-07-23T14:54:54Z] >               assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
[2024-07-23T14:54:54Z] E               assert False
[2024-07-23T14:54:54Z] E                +  where False = <function isclose at 0x7f72938ba070>(0.593, 0.0, rtol=0.02)
[2024-07-23T14:54:54Z] E                +    where <function isclose at 0x7f72938ba070> = numpy.isclose
```

Looks like the measured_value is 0, so the output may be garbage in this case.

comaniac avatar Jul 23 '24 21:07 comaniac

> If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.

I tried Qwen/Qwen-14B-Chat without LoRA. It can be any prompt; the result is totally different with or without enforce-eager.

jjjjohnson avatar Jul 24 '24 06:07 jjjjohnson

> If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.

Looks like short prompts are OK, but if you change to example_long_prompts the tests fail...

jjjjohnson avatar Jul 24 '24 09:07 jjjjohnson

> example_long_prompts
>
> Qwen/Qwen-14B-Chat

I tried example_long_prompts with Qwen and it did fail. But after looking into it, it fails in both eager and non-eager mode. It also fails for other backends such as XFORMERS. Therefore, it seems like a numerical issue in that case. Did you observe something similar?

LiuXiaoxuanPKU avatar Jul 24 '24 11:07 LiuXiaoxuanPKU

> If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
>
> I tried Qwen/Qwen-14B-Chat without LoRA. It can be any prompt; the result is totally different with or without enforce-eager.

Could you provide the exact prompt and the hardware you used? After some manual checking on H100 with Qwen/Qwen-14B-Chat, setting enforce-eager or not gives the same output. It might also be that the CUDA graph preparation bug does not reproduce reliably. Thanks!

LiuXiaoxuanPKU avatar Jul 24 '24 11:07 LiuXiaoxuanPKU

> If I add --enforce-eager, which disables CUDA graphs, the model output text is normal. But if I enable CUDA graphs, the output is totally different. @comaniac @rkooo567 My guess is that flash_attn_varlen_func ONLY works without CUDA graphs... but I don't know why.
>
> I can now reproduce the bug with tests/lora/test_chatglm3.py: if I set enforce-eager, the test passes; otherwise it fails. I'm wondering if your case is related to LoRA. I cannot reproduce the bug without LoRA.
>
> I tried Qwen/Qwen-14B-Chat without LoRA. It can be any prompt; the result is totally different with or without enforce-eager.
>
> Could you provide the exact prompt and the hardware you used? After some manual checking on H100 with Qwen/Qwen-14B-Chat, setting enforce-eager or not gives the same output. It might also be that the CUDA graph preparation bug does not reproduce reliably. Thanks!

I use an A800 with TP=1. Prompt: The rapid advancement in artificial intelligence (AI) has yielded a variety of groundbreaking technologies, among which Large Language Models (LLMs) have garnered widespread attention and utility. LLMs, such as OpenAI’s GPT-4, Google's BERT, and others, have profoundly transformed the landscape of natural language processing (NLP) over the past few years. But what exactly are LLMs, and why are they so significant?At their core, LLMs are a subset of machine learning models designed to understand.LLMs are versatile and can be fine-tuned for a variety of applications. From drafting emails and writing code to translating languages and composing poetry, the potential use cases are vast.

If I change `@pytest.mark.parametrize("backend", ["FLASH_ATTN","XFORMERS"])` to `@pytest.mark.parametrize("backend", ["XFORMERS","FLASH_ATTN"])`, the tests pass... Pretty odd...

jjjjohnson avatar Jul 25 '24 08:07 jjjjohnson

@jjjjohnson, when you say the test fails, was the output gibberish or still something reasonable? Changing the kernel may change the numerics slightly.

> example_long_prompts

I think long prompts are more likely to accumulate numerical error, so this checks out?

jon-chuang avatar Aug 09 '24 00:08 jon-chuang

Updates for this PR:

  1. We will take a less aggressive approach: keep the original kernels for prefill and decoding, and use flash_attn_varlen_func only for mixed batches, i.e. batches that contain both prefill tokens and decoding tokens (see the sketch after this list). The goal is to enable CUDA graphs for chunked prefill and speculative decoding.
  2. We need to debug the CUDA graph compatibility of the flash_attn_varlen_func kernel, as it currently fails the unit tests.
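As a concrete illustration of the mixed-batch layout mentioned in item 1, the sketch below shows how per-sequence query and KV lengths could be turned into the cumulative-length tensors that a single flash_attn_varlen_func call consumes. All numbers and variable names are illustrative, not taken from this PR.

```python
import torch

# Illustrative mixed batch: two chunked-prefill sequences contributing 16 and 8
# new query tokens, followed by three decode sequences contributing 1 token each.
query_lens = [16, 8, 1, 1, 1]        # new query tokens per sequence in this step
seq_lens = [48, 8, 301, 77, 12]      # total KV length per sequence (context + new)

# Cumulative lengths; these tensors must live on the GPU for the kernel call.
cu_seqlens_q = torch.cumsum(torch.tensor([0] + query_lens), dim=0, dtype=torch.int32).to("cuda")
cu_seqlens_k = torch.cumsum(torch.tensor([0] + seq_lens), dim=0, dtype=torch.int32).to("cuda")
max_seqlen_q = max(query_lens)       # 16
max_seqlen_k = max(seq_lens)         # 301

# A single call of the form
#   flash_attn_varlen_func(q_all, key_cache, value_cache,
#                          cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
#                          max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
#                          causal=True, block_table=block_tables)
# would then cover both the prefill rows and the decode rows of the batch.
```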

LiuXiaoxuanPKU avatar Aug 19 '24 18:08 LiuXiaoxuanPKU

Hi @LiuXiaoxuanPKU

Based on your current test case defined in tests/kernels/test_flash_attn.py, here is a modified version: test_varlen_cg.py

It should now pass the given case for mixed prefill and decode with vllm_flash_attn v2.6.2: `python3 -m pytest test_varlen_cg.py`

The major modifications needed to use flash_attn_varlen_func with CUDA graphs are the following:

  • We need to keep max_query_len and max_kv_len static, to ensure these CPU-side variables have no effect on the results.
    • Given that params.num_splits is 0 now and we still provide the page table, it will dispatch to flash_fwd_splitkv_kernel. So we need to keep (i) all the kernel grid dims static, which use max_query_len, and (ii) the kernel template static, which uses max_kv_len.
  • Use static GPU memory for g_cu_query_lens, g_cu_kv_lens, and g_block_tables, and pad the remaining batch indices with non-decreasing seqlens.
    • Because their shapes have a batch dimension, we need to keep them static and pad the rest. This requires capturing with the largest number of query tokens needed. For example, we capture with [(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)] and can then run with [(5, 18), (1, 473), (1, 6), (0, 0), (0, 0), (0, 0), (0, 0)]. The (0, 0) entries are needed to keep the prepared padded cu_*_lens non-decreasing, so that the GPU blocks responsible for the padded batch indices won't pollute the result. (A sketch of this padding scheme follows this list.)
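As referenced in the list above, here is a rough sketch of the padding trick. The buffer names mirror the comment but are illustrative, and this is not the code from test_varlen_cg.py.

```python
import torch

MAX_BATCH = 8  # the batch size the CUDA graph was captured with (illustrative)

# Statically allocated buffers, reused across every graph replay.
g_cu_query_lens = torch.zeros(MAX_BATCH + 1, dtype=torch.int32, device="cuda")
g_cu_kv_lens = torch.zeros(MAX_BATCH + 1, dtype=torch.int32, device="cuda")


def fill_cu_lens(buf: torch.Tensor, lens: list[int]) -> None:
    """Write the real cumulative lengths, then pad the tail by repeating the
    last prefix sum. Padded slots behave like (0, 0) sequences: the prefix
    sums stay non-decreasing, so thread blocks assigned to the padded batch
    indices have zero work and cannot pollute the real outputs."""
    cu = torch.cumsum(torch.tensor([0] + lens), dim=0, dtype=torch.int32)
    buf[: cu.numel()].copy_(cu.to(buf.device))
    buf[cu.numel():].fill_(int(cu[-1]))


# Replay with a real batch of 3 sequences with (query_len, kv_len) = (5, 18), (1, 473), (1, 6):
fill_cu_lens(g_cu_query_lens, [5, 1, 1])
fill_cu_lens(g_cu_kv_lens, [18, 473, 6])
```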

Feel free to try it out. Hope it helps :)

pengwu22 avatar Sep 20 '24 02:09 pengwu22

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LiuXiaoxuanPKU.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Nov 26 '24 05:11 mergify[bot]

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

github-actions[bot] avatar Feb 25 '25 02:02 github-actions[bot]

You're trying to merge from a branch that is "2739 commits behind". There is no chance your PR will pass all the CI tests. I advise rebasing your feature branch to closely track the current state of main in vllm-project/vllm.

Thanks!

Alexei-V-Ivanov-AMD avatar Feb 25 '25 16:02 Alexei-V-Ivanov-AMD