Applying FA3 in qwen2 model fine-tuning is slower than FA2
Hello, I applied FA3 to fine-tuning of the Qwen2 model on an H800 machine. Under the same conditions, the test was slower than with FA2.
I used FlashAttnFunc.forward from the hopper/flash_attn_interface.py file to replace Qwen2Attention.forward. In flash_attn_interface.py I added:
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2Model,
    rotate_half,
)

def replace_qwen2_attn_with_flash_attn3():
    Qwen2Attention.forward = FlashAttnFunc.forward
Then I turned off attn_implementation="flash_attention_2" in the fine-tuning code and imported the modified part.
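(As a side note: FlashAttnFunc.forward is an autograd.Function staticmethod whose first argument is ctx and whose inputs are already-projected q/k/v tensors, so assigning it directly over Qwen2Attention.forward changes the signature the model calls. Below is a minimal sketch of the more common pattern, keeping the module's own forward and only swapping the kernel call; it assumes FA3 is importable as flash_attn_interface, and the helper name fa3_attention is illustrative, not code from this thread.)

# FA3 attention entry point from hopper/flash_attn_interface.py
from flash_attn_interface import flash_attn_func as flash_attn_func_v3

def fa3_attention(q, k, v, causal=True, softmax_scale=None):
    # q, k, v: (batch, seqlen, num_heads, head_dim) fp16/bf16 CUDA tensors,
    # i.e. the same layout the FA2 call inside the patched forward would use.
    out = flash_attn_func_v3(q, k, v, softmax_scale=softmax_scale, causal=causal)
    # Some FA3 builds return (out, softmax_lse); keep only the attention output.
    return out[0] if isinstance(out, tuple) else out

Such a helper would then be called from inside a re-implemented Qwen2Attention.forward (after the q/k/v projections and rotary embedding), rather than replacing the whole method.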
Setup: 1x H800-80G, 32 CPU cores, 256 GB memory; Qwen2-7B, 45k training samples, 6.5k training sequence length.
With FA3 the speed is about 34 s/it,
but with FA2 it is about 24 s/it.
No significant difference in memory usage was observed.
May I ask if I did something wrong? Thank you.
Are you using the latest commit? There's a recent update to enable causal for the backward. Can you profile to get the time for the attention kernel?
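(For reference, a minimal profiling sketch with torch.profiler, assuming a run_one_step() callable that performs one forward/backward pass; the kernel names to look for, e.g. rows containing "flash" or "fmha", depend on the build.)

import torch
from torch.profiler import profile, ProfilerActivity

def profile_attention(run_one_step):
    # Time one training step and print per-kernel CUDA totals; the attention
    # kernel rows give the FA2-vs-FA3 comparison directly.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        run_one_step()
        torch.cuda.synchronize()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))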
When I first configured the environment, I also ran into issue #1091. After fixing that, I ran the tests on the setup I got working around August 1st. When was the commit you mentioned submitted? Was I using the latest one?
This commit: https://github.com/Dao-AILab/flash-attention/commit/bafe253042fb251a28f351ad0a2657da26263f31
OK, I'll use this commit to test it again.
How about the performance? When I pretrain DeepSeek-V2 on an H100-80G, I hit the same issue (FA3 is slower than FA2).
Can you profile to get the time for the attention kernel?
Sorry, I've been busy with other things recently. We may wait until FA3 is officially released before using it.
Same issue when fine-tuning both Llama-3 and Qwen2 models. FA3 takes more time and slightly more GPU memory (not sure) than FA2. I replaced the same function, flash_attn_varlen_func in transformers/modeling_flash_attention_utils.py, switching from FA2 to FA3. Maybe that is not the right way :(
Same issue here. I have a BERT model with a BertSelfFlashAttention class, whose core FA call is flash_attn_varlen_func. To move the code from FA2 to FA3, I just import flash_attn_varlen_func from FA3 instead of FA2. However, the running speed drops.
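(A minimal sketch of that kind of swap, assuming the call sites pass FA2-style keyword arguments; FA3's varlen function takes no dropout_p, so a thin shim that drops FA2-only arguments avoids signature errors. The shim and its kwargs handling are illustrative, not code from this thread.)

from flash_attn_interface import flash_attn_varlen_func as _fa3_varlen  # FA3

def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                           max_seqlen_q, max_seqlen_k,
                           dropout_p=0.0, softmax_scale=None, causal=False,
                           **unused_fa2_kwargs):
    # dropout_p (and any other FA2-only kwargs) are accepted for compatibility
    # but ignored, since the FA3 forward has no dropout.
    out = _fa3_varlen(q, k, v,
                      cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                      max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
                      softmax_scale=softmax_scale, causal=causal)
    return out[0] if isinstance(out, tuple) else out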
Can you profile and post how long the FA2 kernel and the FA3 kernel take, and what the input shapes are?
@tridao Could you give me any hints based on the following profiling code, which is straightforward to run?
The results are:
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2632.02it/s]
FA3 costs: 3.8018763065338135
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2834.98it/s]
FA2 costs: 3.5276150703430176
Environment:
GPU: A100, Python 3.9, CUDA 12.4, torch 2.6.0, FA2 2.7.4.post1, FA3 3.0.0b1
Code here:
import torch
import time
import tqdm

# FA3 (hopper) and FA2 variable-length attention functions
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
from flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_v2

def main():
    # 6 sequences of length 2048 packed into (total_tokens, num_heads, head_dim)
    size = [12288, 12, 32]
    max_seqlen_in_batch_q = max_seqlen_in_batch_k = 2048
    cu_seqlens_q = [0, 2048, 4096, 6144, 8192, 10240, 12288]
    cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, device=torch.device('cuda'))
    cu_seqlens_k = cu_seqlens_q
    query_states = torch.rand(size, dtype=torch.float16, device=torch.device('cuda'))
    key_states = torch.rand(size, dtype=torch.float16, device=torch.device('cuda'))
    value_states = torch.rand(size, dtype=torch.float16, device=torch.device('cuda'))
    N = 10000

    tic = time.time()
    for k in tqdm.tqdm(range(N)):
        attn_output_unpad_v3 = flash_attn_varlen_func_v3(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=None,
            causal=False,
        )
    torch.cuda.synchronize()  # wait for queued kernels before reading the clock
    toc = time.time()
    print(f'FA3 costs: {toc - tic}')

    tic = time.time()
    for k in tqdm.tqdm(range(N)):
        attn_output_unpad_v2 = flash_attn_varlen_func_v2(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0,
            softmax_scale=None,
            causal=False,
        )
    torch.cuda.synchronize()
    toc = time.time()
    print(f'FA2 costs: {toc - tic}')

if __name__ == '__main__':
    main()
I see you're using hdim 32. FA2 does have an implementation for hdim 32, but FA3 does not specialize for hdim 32; instead it rounds it up to hdim 64. You can try hdim 64.
Yes, after changing to hdim 64, FA3 is now faster than FA2.
size: [12288, 12, 64]
100%|███████████████████████████████████| 10000/10000 [00:04<00:00, 2411.88it/s]
FA3 costs: 4.149170160293579
100%|███████████████████████████████████| 10000/10000 [00:04<00:00, 2041.24it/s]
FA2 costs: 4.899213075637817
However, since my pretrained network uses hdim 32, even if I use some trick to run it through the hdim 64 path in FA3 (e.g. zero-padding to 64, as sketched below), the time cost would be 4.14 s vs 3.52 s (hdim 32 in FA2, as mentioned in my previous post). Would it be easy to specialize FA3 for hdim 32?
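(A minimal sketch of that padding trick, using the same varlen layout as the benchmark above: zero-padding q and k leaves the attention scores unchanged, zero-padding v leaves the extra output channels at zero, and softmax_scale has to be pinned to the original 1/sqrt(32) because the default would use 1/sqrt(64).)

import torch.nn.functional as F
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

def fa3_varlen_hdim32(q, k, v, cu_seqlens_q, cu_seqlens_k,
                      max_seqlen_q, max_seqlen_k, causal=False):
    d = q.shape[-1]                      # 32 here
    pad = 64 - d
    q64 = F.pad(q, (0, pad))             # zero-pad the head dimension to 64
    k64 = F.pad(k, (0, pad))
    v64 = F.pad(v, (0, pad))             # padded output channels stay zero
    out = flash_attn_varlen_func_v3(
        q64, k64, v64,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        softmax_scale=d ** -0.5,         # keep the original 1/sqrt(32) scaling
        causal=causal,
    )
    out = out[0] if isinstance(out, tuple) else out
    return out[..., :d]                  # slice back to head_dim 32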
It's not hard, but we don't plan to specialize for hdim 32 in FA3, to keep compilation time down (it's already taking very long to compile). Most models use hdim 64 or 128.
If it's not hard, could you give me some hints so that I can adapt your source code (C++/CUDA) to support hdim 32? I can compile it myself. Yes, the compile time is very long right now, but I really hope to apply FA3 to my model to speed it up.
The code stays the same, just change the dispatching.
You want to add to tile_size:
https://github.com/Dao-AILab/flash-attention/blob/main/hopper/tile_size.h
Then add the new hdim to generate_kernels and setup.py: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/generate_kernels.py
Then add to the dispatching:
https://github.com/Dao-AILab/flash-attention/blob/06e34f62d18d3a721bc515d4b331a46d5d4c8c09/hopper/flash_api.cpp#L259
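(Schematically, and with a hypothetical variable name since the real lists in hopper/generate_kernels.py and hopper/setup.py may be organized differently, the second step amounts to making 32 part of whatever head-dimension list those two files iterate over when emitting and compiling the per-hdim kernel instantiations.)

# Hypothetical illustration only: HEAD_DIMENSIONS stands in for whatever list
# hopper/generate_kernels.py and hopper/setup.py actually loop over.
HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]  # 32 added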
Really thankful for the hint. I looked carefully into the code, and I guess generate_kernels.py and flash_api.cpp may not be that hard to extend to hdim 32; however, for tile_size I cannot guess the values. Since I am on an A100, the function I'm interested in is tile_size_fwd_sm8x, i.e. adding something before:
if (headdim <= 64) {
    return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false};
} else if (headdim <= 96) {
    return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false};
} else if (headdim <= 128) {
    bool const use_8_warps = sm86_or_89 | varlen_and_split;
    return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps};
} else if (headdim <= 192) {
    bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;
    return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};
} else {
    return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv};
}
Could you let me know what should be returned here?

if (headdim <= 32) { return {?????}; }
These are all tuned empirically. You can copy the hdim64 case (e.g. 128 x 112). The tile size 128 x 128 should also work.
@tridao I successfully implemented hdim 32 based on your code; however, the speed is now only about equal to FA2 (faster than before).
size: [12288, 12, 32]
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2840.81it/s]
FA2 costs: 3.5226664543151855
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2837.76it/s]
FA3 costs: 3.524158239364624
For tile_size.h, I just modified the sm8x function, adding:

if (headdim <= 32) {
    return {128, varlen_and_split ? 112 : (is_local ? 128 : 144), 4, 1, false};
}

so the full function now looks like:
constexpr std::tuple<int, int, int, int, bool> tile_size_fwd_sm8x(
    bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,
    bool paged_kv=false, bool varlen_and_split=false,
    bool softcap=false, bool append_kv=false) {
    if (element_size == 2) {
        if (headdim <= 32) {
            return {128, varlen_and_split ? 112 : (is_local ? 128 : 144), 4, 1, false};
        }
        if (headdim <= 64) {
            return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false};
        } else if (headdim <= 96) {
            return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false};
        } else if (headdim <= 128) {
            bool const use_8_warps = sm86_or_89 | varlen_and_split;
            return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps};
        } else if (headdim <= 192) {
            bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;
            return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};
        } else {
            return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv};
        }
    } else {
        // Placeholder for now
        return {128, 64, 8, 2, false};
    }
}
How do you think I can tune the code to speed it up? Is tile_size.h the only file I need to tune manually?
That sounds right. You can tune tile_size.h, but I'm not sure how much that would improve things. Hdim 32 is just not very hardware-friendly.