
FlashAttention train kernels

Opened by danyao12 • 2 comments

FlashAttentionV1:

  • forward kloop: gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp
  • backward kloop prototype1: gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
  • backward kloop prototype2: gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp

FlashAttentionV2:

  • forward kloop: gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp
  • backward qloop from bottom to top prototype1: gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
  • backward qloop from bottom to top prototype2: gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
  • backward qloop from bottom to top prototype1 w/o d calculation: gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
  • backward qloop from bottom to top prototype2 w/o d calculation: gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
  • backward qloop d = rowsum(do*o) (see the sketch below): gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp
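For reference, the d = rowsum(do*o) quantity produced by the ydotygrad kernel is the usual FlashAttention backward precompute: for each query row i, d_i is the sum over the head dimension of dO[i][j] * O[i][j]. Below is a minimal host-side C++ reference of that math; it is illustrative only (not the CK kernel), and the function and variable names are our own.

// Illustrative reference for d = rowsum(dO * O), i.e.
// d[i] = sum_j dO[i][j] * O[i][j], computed once per query row.
// Not CK code; dO and O are stored row-major as rows x head_dim.
#include <vector>

std::vector<float> rowsum_do_o(const std::vector<float>& dO,
                               const std::vector<float>& O,
                               int rows, int head_dim)
{
    std::vector<float> d(rows, 0.0f);
    for(int i = 0; i < rows; ++i)
        for(int j = 0; j < head_dim; ++j)
            d[i] += dO[i * head_dim + j] * O[i * head_dim + j];
    return d;
}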

Tasks in progress:

  • [x] Qloop bwd 2 split kernels waiting to merge
  • [x] Bias function
  • [x] CausalFromBottomRight
  • [x] Grouped Query Attention/Multi Query Attention

We use the following scripts for testing.

Run inference examples

#!/bin/bash

# Inference (fwd-only) attention examples to run.
TESTS="example_batched_gemm_scale_softmax_gemm_xdl_fp16                        \
       example_batched_gemm_scale_softmax_gemm_xdl_bf16                        \
       example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16                \
       example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16                \
       example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16                \
       example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 \
       example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16"

FAILED_TESTS=""

# Run each example binary and collect the names of any that exit non-zero.
for test in $TESTS; do
    if ./bin/"$test"; then
        echo "$test succeeded!"
    else
        FAILED_TESTS="$FAILED_TESTS $test"
    fi
done

# Report the failures at the end.
for test in $FAILED_TESTS; do
    echo "$test failed"
done

Run training examples

#!/bin/bash

# Training (fwd + bwd) multi-head attention examples to run.
TESTS="example_grouped_multihead_attention_forward_v1   \
       example_batched_multihead_attention_forward_v1   \
       example_grouped_multihead_attention_backward_v1  \
       example_batched_multihead_attention_backward_v1  \
       example_grouped_multihead_attention_train_v1     \
       example_batched_multihead_attention_train_v1     \
       example_grouped_multihead_attention_forward_v2   \
       example_batched_multihead_attention_forward_v2   \
       example_grouped_multihead_attention_backward_v2  \
       example_batched_multihead_attention_backward_v2  \
       example_grouped_multihead_attention_train_v2     \
       example_batched_multihead_attention_train_v2     \
       example_grouped_multihead_attention_backward_v3  \
       example_batched_multihead_attention_backward_v3"

FAILED_TESTS=""

# Run each example binary and collect the names of any that exit non-zero.
for test in $TESTS; do
    if ./bin/"$test"; then
        echo "$test succeeded!"
    else
        FAILED_TESTS="$FAILED_TESTS $test"
    fi
done

# Report the failures at the end.
for test in $FAILED_TESTS; do
    echo "$test failed"
done

danyao12 commented on Jul 25, 2023

Because qloop's fwd and bwd use different layouts, we refactored dropout to decouple fwd from bwd. However, the refactored dropout added a lot of overhead to fwd, for three main reasons:

1. The dropout shuffle is performed in 4x4 units, which wastes the get_random_8x16 instruction: each call generates eight random numbers per thread, but only the first four of them are used (see the sketch after this list).
2. The transpose shuffle of the random-number matrix through LDS causes bank conflicts.
3. Fwd headdim128 has some scratch both before and after the dropout refactor.
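To make reason 1 concrete, here is a conceptual C++ sketch of the waste (our own illustration, not CK code; the real get_random_8x16 is modeled by a placeholder generate_random_8 backed by std::mt19937). Each call yields eight random bytes per thread, but the per-thread share of a 4x4 shuffle unit consumes only four of them, so half of every call's output is discarded and twice as many calls are issued as strictly needed.

// Conceptual sketch only: models one thread's view of dropout mask
// generation when the shuffle unit is 4x4.
#include <array>
#include <cstdint>
#include <random>

// Placeholder for get_random_8x16's per-thread output: 8 random bytes per call.
std::array<uint8_t, 8> generate_random_8(std::mt19937& rng)
{
    std::array<uint8_t, 8> r{};
    for(auto& v : r)
        v = static_cast<uint8_t>(rng());
    return r;
}

// Keep-mask for this thread's share of one 4x4 unit: consumes r[0..3] only,
// so r[4..7] from the same call are wasted.
std::array<bool, 4> keep_mask_4x4(std::mt19937& rng, uint8_t drop_threshold)
{
    const auto r = generate_random_8(rng);
    std::array<bool, 4> keep{};
    for(int i = 0; i < 4; ++i)
        keep[i] = (r[i] >= drop_threshold);
    return keep;
}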

I haven't thought of a good solution for point 2. For point 1, I optimized dropout to use 16x16 units so that the get_random_8x16 instruction is fully utilized; for example, in the fwd kernel random numbers are exchanged between lanes 0-15 and lanes 32-47 through LDS (sketched below). For point 3, when Gemm1NPerBlock is set to the head dim (128), TotalNumVgprs is already close to 250 even without dropout, so it is hard to optimize dropout enough to eliminate scratch; in that case I tried halving Gemm1NPerBlock to 64, which doubles the number of workgroups but eliminates the scratch. Note that when the lse and the dropout random matrix are saved, they must not be written out repeatedly.
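As a rough picture of that lane exchange, the HIP snippet below swaps per-lane values between lanes 0-15 and lanes 32-47 of a wavefront through LDS. It is a sketch under assumptions (64-lane wavefront, one wavefront per block, made-up names), not the actual CK fwd kernel.

// Illustrative only: lanes 0-15 and 32-47 trade their staged values via LDS.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void swap_lane_groups_via_lds(const unsigned* in, unsigned* out)
{
    __shared__ unsigned lds[64];   // one slot per lane of the wavefront
    const int lane = threadIdx.x;  // blockDim.x == 64 assumed

    lds[lane] = in[lane];          // every lane stages its value in LDS
    __syncthreads();

    // Lanes 0-15 pick up what lanes 32-47 staged and vice versa;
    // the remaining lanes keep their own values.
    int src = lane;
    if(lane < 16)
        src = lane + 32;
    else if(lane >= 32 && lane < 48)
        src = lane - 32;

    out[lane] = lds[src];
}

int main()
{
    unsigned h_in[64], h_out[64];
    for(int i = 0; i < 64; ++i)
        h_in[i] = i;

    unsigned *d_in, *d_out;
    hipMalloc(&d_in, sizeof(h_in));
    hipMalloc(&d_out, sizeof(h_out));
    hipMemcpy(d_in, h_in, sizeof(h_in), hipMemcpyHostToDevice);

    swap_lane_groups_via_lds<<<1, 64>>>(d_in, d_out);

    hipMemcpy(h_out, d_out, sizeof(h_out), hipMemcpyDeviceToHost);
    printf("lane 0 now holds the value from lane %u\n", h_out[0]); // prints 32
    hipFree(d_in);
    hipFree(d_out);
    return 0;
}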

The following is a comparison before and after the optimization. I expect that tuning the kernels with dropout later using the CK profiler will give better performance.

| Dtype | Device | Train microbatch size | Num_heads | Head_dim | Seqlen | Causal | Dropout | Fwd before | Bwd before | Fwd after | Bwd after | Improved |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| fp16 | | 16 | 12 | 64 | 1K | FALSE | 0.2 | 1.39495 ms, 36.9472 TFlops | 3.60971 ms, 35.6952 TFlops | 1.09508 ms, 47.0647 TFlops | 3.09402 ms, 41.6445 TFlops | 22.12% |
| fp16 | | 16 | 12 | 64 | 2K | FALSE | 0.2 | 5.38146 ms, 38.309 TFlops | 13.9454 ms, 36.958 TFlops | 4.22661 ms, 48.7764 TFlops | 11.8978 ms, 43.3185 TFlops | 22.36% |
| fp16 | | 8 | 16 | 64 | 1K | FALSE | 0.2 | 0.941686 ms, 36.4875 TFlops | 2.41688 ms, 35.5414 TFlops | 0.741349 ms, 46.3476 TFlops | 2.078 ms, 41.3375 TFlops | 21.74% |
| fp16 | | 8 | 16 | 64 | 2K | FALSE | 0.2 | 3.60951 ms, 38.0769 TFlops | 9.31209 ms, 36.898 TFlops | 2.8365 ms, 48.4537 TFlops | 7.93545 ms, 43.2991 TFlops | 22.38% |
| fp16 | | 8 | 16 | 64 | 4K | FALSE | 0.2 | 14.1361 ms, 38.8901 TFlops | 36.5511 ms, 37.6019 TFlops | 11.2258 ms, 48.9723 TFlops | 31.1255 ms, 44.1564 TFlops | 21.75% |
| fp16 | | 4 | 16 | 128 | 2K | FALSE | 0.2 | 4.00785 ms, 34.2924 TFlops | 8.70397 ms, 39.4759 TFlops | 3.3851 ms, 40.6012 TFlops | 8.44611 ms, 40.6812 TFlops | 10.19% |
| fp16 | | 8 | 20 | 128 | 2K | FALSE | 0.2 | 9.8889 ms, 34.7458 TFlops | 21.7228 ms, 39.5434 TFlops | 8.37536 ms, 41.0248 TFlops | 21.0789 ms, 40.7513 TFlops | 10.08% |
| fp16 | | 8 | 32 | 128 | 2K | FALSE | 0.2 | 15.9766 ms, 34.4101 TFlops | 34.726 ms, 39.5781 TFlops | 13.31 ms, 41.304 TFlops | 33.6967 ms, 40.787 TFlops | 10.95% |
| fp16 | | 8 | 40 | 128 | 2K | FALSE | 0.2 | 19.9986 ms, 34.3622 TFlops | 43.441 ms, 39.5476 TFlops | 16.6432 ms, 41.2898 TFlops | 42.1168 ms, 40.791 TFlops | 11.06% |

danyao12 commented on Aug 11, 2023

Is it possible to merge this with the current develop branch and resolve all the conflicts? I'd like to use the MHA kernels in the MIOpen library, but this branch is far too old.

CAHEK7 commented on Oct 20, 2023