
Applying FA3 in qwen2 model fine-tuning is slower than FA2

Open 982118809 opened this issue 1 year ago • 8 comments

Hello, I applied FA3 to the fine-tuning of the Qwen2 model on an H800 machine. Under the same conditions, it was slower than FA2.

I used FlashAttnFunc.forward from hopper/flash_attn_interface.py to replace Qwen2Attention.forward. In flash_attn_interface.py I added:

from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2Model,
    rotate_half,
)

def replace_qwen2_attn_with_flash_attn3():
    Qwen2Attention.forward = FlashAttnFunc.forward

Then I turned off attn_implementation="flash_attention_2" in the fine-tuning code and imported the modified module.
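
For reference, a minimal sketch of how this kind of replacement is usually wired up: a new forward with the same outer signature that calls FA3's flash_attn_func for the core attention. It deliberately skips rotary embeddings, the KV cache, and padding masks, which a real patch has to keep from the original Qwen2Attention.forward, and the attribute names follow the transformers release current at the time (newer releases read them from the config):

import torch
from flash_attn_interface import flash_attn_func  # FA3, built from the hopper/ directory
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention

def qwen2_fa3_forward(self, hidden_states, attention_mask=None, position_ids=None,
                      past_key_value=None, output_attentions=False, use_cache=False,
                      **kwargs):
    bsz, q_len, _ = hidden_states.size()
    # Project and reshape to (batch, seqlen, nheads, headdim), the layout FA expects.
    q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    # NOTE: rotary embeddings must be applied to q and k here, exactly as in the
    # original forward; omitted in this sketch. FA handles GQA (fewer k/v heads) natively.
    out = flash_attn_func(q, k, v, causal=True)
    if isinstance(out, tuple):  # some FA3 builds also return the softmax LSE
        out = out[0]
    out = out.reshape(bsz, q_len, self.num_heads * self.head_dim)
    return self.o_proj(out), None, past_key_value

def replace_qwen2_attn_with_flash_attn3():
    Qwen2Attention.forward = qwen2_fa3_forward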

Setup: 1x H800-80G, 32 CPUs, 256 GB memory; Qwen2-7B, 45k samples, 6.5k training sequence length.

With FA3, the speed is about 34 s/it,

but with FA2, the speed is about 24 s/it.

No significant difference in memory usage was observed.

May I ask if I did something wrong? Thank you.

982118809 avatar Aug 05 '24 07:08 982118809

Are you using the latest commit? There's a recent update to enable causal for the backward. Can you profile to get the time for the attention kernel?
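
For example, a minimal torch.profiler sketch, where run_step() is a hypothetical stand-in for one training step (forward + backward):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_step()  # hypothetical: one forward + backward pass of the fine-tuning loop

# The FA2/FA3 attention kernels appear by name in this table with their CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))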

tridao avatar Aug 05 '24 07:08 tridao

When I first configured the environment, I also ran into issue #1091. After fixing that, I ran the tests once the configuration succeeded, around August 1st. When was the commit you mentioned submitted? Was I using the latest one?

982118809 avatar Aug 05 '24 08:08 982118809

This commit: https://github.com/Dao-AILab/flash-attention/commit/bafe253042fb251a28f351ad0a2657da26263f31

tridao avatar Aug 05 '24 16:08 tridao

OK, I'll use this commit to test it again.

982118809 avatar Aug 08 '24 07:08 982118809

> OK, I'll use this commit to test it again.

How is the performance now? When I pretrained DeepSeek-V2 on an H100-80G, I hit the same issue (FA3 was slower than FA2).

BlackBearBiscuit avatar Aug 16 '24 03:08 BlackBearBiscuit

Can you profile to get the time for the attention kernel?

tridao avatar Aug 16 '24 03:08 tridao

> OK, I'll use this commit to test it again.
>
> How is the performance now? When I pretrained DeepSeek-V2 on an H100-80G, I hit the same issue (FA3 was slower than FA2).

Sorry, I've been busy with other things recently. We may wait until FA3 is officially released before using it.

982118809 avatar Aug 16 '24 08:08 982118809

Same issue when fine-tuning both llama3 and qwen2 models. FA3 takes more time and slightly more GPU memory (not sure) than FA2. I replaced the flash_attn_varlen_func used in transformers/modeling_flash_attention_utils.py, swapping the FA2 import for the FA3 one. Maybe that is not the right way :(
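
For reference, that swap generally needs a small shim because the two signatures differ (FA3 has no attention dropout, for instance). A minimal sketch, assuming FA3 was built from the hopper/ directory and imports as flash_attn_interface:

import flash_attn_interface  # FA3
import transformers.modeling_flash_attention_utils as fa_utils

def fa3_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                    dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    # FA3 does not implement attention dropout, so dropout_p (and any other
    # FA2-only keyword that ends up in **kwargs) is silently ignored here.
    out = flash_attn_interface.flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale, causal=causal,
    )
    # Some FA3 builds return (out, softmax_lse); keep only the attention output.
    return out[0] if isinstance(out, tuple) else out

# Patch before the model is instantiated so _flash_attention_forward picks it up.
fa_utils.flash_attn_varlen_func = fa3_varlen_func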

albaNnaksqr avatar Aug 26 '24 10:08 albaNnaksqr

Same issue here. I have a BERT model with a BertSelfFlashAttention class whose core FA call is flash_attn_varlen_func. To move the code from FA2 to FA3, I simply import flash_attn_varlen_func from FA3 instead of FA2. However, the running speed drops.

wangchuan avatar Feb 22 '25 06:02 wangchuan

Can you profile and post how long the FA2 kernel and FA3 kernel take? And what are the input shapes?

tridao avatar Feb 22 '25 15:02 tridao

@tridao Could you give me any hints based on the following profiling code? It is straightforward to run.

The results are:

100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2632.02it/s]
FA3 costs: 3.8018763065338135
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2834.98it/s]
FA2 costs: 3.5276150703430176

Environment:

GPU: A100, Python 3.9, CUDA 12.4, torch 2.6.0, FA2 2.7.4.post1, FA3 3.0.0b1

Code here:

import torch
import time
import tqdm

from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
from flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_v2

def main():
    # Packed (varlen) layout: q/k/v are [total_tokens, nheads, headdim] = [12288, 12, 32],
    # i.e. six sequences of length 2048 concatenated together.
    size = [12288, 12, 32]
    max_seqlen_in_batch_q = max_seqlen_in_batch_k = 2048
    cu_seqlens_q = [0, 2048, 4096, 6144, 8192, 10240, 12288]
    cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, device=torch.device('cuda'))
    cu_seqlens_k = cu_seqlens_q
    query_states = torch.rand(size, dtype=torch.float16, device=torch.device('cuda'))
    key_states = torch.rand(size, dtype=torch.float16, device=torch.device('cuda'))
    value_states = torch.rand(size, dtype=torch.float16, device=torch.device('cuda'))

    N = 10000
    tic = time.time()
    for k in tqdm.tqdm(range(N)):
        attn_output_unpad_v3 = flash_attn_varlen_func_v3(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=None,
            causal=False,
        )
    toc = time.time()
    print(f'FA3 costs: {toc - tic}')

    tic = time.time()
    for k in tqdm.tqdm(range(N)):
        attn_output_unpad_v2 = flash_attn_varlen_func_v2(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0,
            softmax_scale=None,
            causal=False,
        )
    toc = time.time()
    print(f'FA2 costs: {toc - tic}')


if __name__ == '__main__':
    main()
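
One caveat with a wall-clock loop like this: CUDA kernel launches are asynchronous, so time.time() can be read before the queued kernels have finished, and the numbers partly reflect launch overhead. Synchronizing before reading the clock gives more reliable kernel timings; a minimal sketch of the pattern:

import time
import torch

def timed(fn, n=10000):
    torch.cuda.synchronize()   # make sure no earlier work is still in flight
    tic = time.time()
    for _ in range(n):
        fn()
    torch.cuda.synchronize()   # wait for all queued kernels to finish
    return time.time() - tic

Wrapping each of the two calls above in timed() keeps the comparison on the actual kernel time.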

wangchuan avatar Feb 23 '25 01:02 wangchuan

I see you're using hdim 32. FA2 does have an implementation for hdim 32, but FA3 does not specialize for hdim 32 and instead rounds it up to hdim 64. You can try hdim 64.

tridao avatar Feb 23 '25 02:02 tridao

Yes, after changing to hdim 64, FA3 is now faster than FA2.

size: [12288, 12, 64]
100%|███████████████████████████████████| 10000/10000 [00:04<00:00, 2411.88it/s]
FA3 costs: 4.149170160293579
100%|███████████████████████████████████| 10000/10000 [00:04<00:00, 2041.24it/s]
FA2 costs: 4.899213075637817

However, since my pretrained network uses hdim 32, even if I use some trick to substitute hdim 64 for hdim 32 in FA3, the time cost would be 4.14 s vs 3.52 s (hdim 32 in FA2, as reported in my previous post). Would it be easy to specialize FA3 for hdim 32?
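
For reference, one such trick is zero-padding the head dimension: padding q, k, and v from 32 to 64 leaves the attention scores unchanged as long as softmax_scale is kept at 1/sqrt(32), and the padded output channels come back as zeros. A minimal sketch, reusing the varlen call from the benchmark above:

import math
import torch.nn.functional as F
from flash_attn_interface import flash_attn_varlen_func  # FA3

def fa3_hdim32_via_pad(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                       causal=False):
    d = q.shape[-1]          # 32 here
    pad = 64 - d
    # Zero-padding q and k leaves every q.k dot product unchanged, and the padded
    # v channels contribute only zeros to the output.
    q64, k64, v64 = (F.pad(t, (0, pad)) for t in (q, k, v))
    out = flash_attn_varlen_func(
        q64, k64, v64,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        softmax_scale=1.0 / math.sqrt(d),  # keep the original 1/sqrt(32) scaling
        causal=causal,
    )
    out = out[0] if isinstance(out, tuple) else out
    return out[..., :d]      # drop the padded channels

The padding and slicing add their own memory traffic, so this mainly tells you whether the hdim-64 kernel is fast enough to be worth it.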

wangchuan avatar Feb 23 '25 02:02 wangchuan

It's not hard, but we don't plan to specialize for hdim 32 in FA3, to keep compilation time down (it's already taking very long to compile). Most models use hdim 64 or 128.

tridao avatar Feb 23 '25 02:02 tridao

If it is not hard, could you give me some hints so that I can adapt your source code (C++/CUDA) to support hdim 32? I can compile it myself. Yes, the compile time is already quite long, but I really hope to apply FA3 to my model to speed it up.

wangchuan avatar Feb 23 '25 02:02 wangchuan

The code stays the same; you just change the dispatching:

1. Add a case to tile_size: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/tile_size.h
2. Add the new hdim to generate_kernels and setup.py: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/generate_kernels.py
3. Add it to the dispatching: https://github.com/Dao-AILab/flash-attention/blob/06e34f62d18d3a721bc515d4b331a46d5d4c8c09/hopper/flash_api.cpp#L259

tridao avatar Feb 23 '25 03:02 tridao

Really thankful for the hint. I looked into the code carefully, and I think adding hdim 32 to generate_kernels.py and flash_api.cpp is not that hard; however, for tile_size I cannot guess the values. Since I am on an A100, the function I care about is tile_size_fwd_sm8x, i.e. adding something before:

if (headdim <= 64) {
    return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false};
} else if (headdim <= 96) {
    return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false};
} else if (headdim <= 128) {
    bool const use_8_warps = sm86_or_89 | varlen_and_split;
    return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps};
} else if (headdim <= 192) {
    bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;
    return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};
} else {
    return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv};
}

Could you let me know what should be returned here: if (headdim <= 32) { return {?????}; }?

wangchuan avatar Feb 23 '25 08:02 wangchuan

These are all tuned empirically. You can copy the hdim64 case (e.g. 128 x 112). The tile size 128 x 128 should also work.

tridao avatar Feb 23 '25 15:02 tridao

@tridao I successfully implemented hdim 32 based on your code; however, the speed is only about equal to FA2 now (faster than before).

size: [12288, 12, 32]
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2840.81it/s]
FA2 costs: 3.5226664543151855
100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2837.76it/s]
FA3 costs: 3.524158239364624

For tile_size.h, I just modified the sm8x function (tile_size_fwd_sm8x), adding:

if (headdim <= 32) {
    return {128, varlen_and_split ? 112 : (is_local ? 128 : 144), 4, 1, false};
}

so that the function now reads:

constexpr std::tuple<int, int, int, int, bool> tile_size_fwd_sm8x(
        bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,
        bool paged_kv=false, bool varlen_and_split=false,
        bool softcap=false, bool append_kv=false) {
    if (element_size == 2) {
        if (headdim <= 32) {
            return { 128, varlen_and_split ? 112 : (is_local ? 128 : 144), 4, 1, false };
        }
        if (headdim <= 64) {
            return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false};
        } else if (headdim <= 96) {
            return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false};
        } else if (headdim <= 128) {
            bool const use_8_warps = sm86_or_89 | varlen_and_split;
            return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps};
        } else if (headdim <= 192) {
            bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;
            return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};
        } else {
            return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv};
        }
    } else {
        // Placeholder for now
        return {128, 64, 8, 2, false};
    }
}

How do you think I can tune the code to speed it up? Is tile_size.h the only file I need to tune manually?

wangchuan avatar Feb 24 '25 02:02 wangchuan

That sounds right. You can tune tile_size.h, but I'm not sure how much that would improve things. Hdim 32 is just not very hardware-friendly.

tridao avatar Feb 24 '25 03:02 tridao