composable_kernel
FlashAttention train kernels
FlashAttentionV1:
- forward kloop: `gridwise_batched_mha_fwd_xdl_cshuffle_v1.hpp`
- backward kloop prototype 1: `gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp`
- backward kloop prototype 2: `gridwise_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp`

FlashAttentionV2:
- forward kloop: `gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp`
- backward qloop from bottom to top, prototype 1: `gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp`
- backward qloop from bottom to top, prototype 2: `gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp`
- backward qloop from bottom to top, prototype 1 w/o d calculation: `gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp`
- backward qloop from bottom to top, prototype 2 w/o d calculation: `gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp`
- backward qloop d = rowsum(dO * O): `gridwise_batched_mha_bwd_xdl_cshuffle_qloop_ydotygrad.hpp`
Tasks in progress:
- [x] Qloop bwd 2 split kernels waiting to merge
- [x] Bias function
- [x] CausalFromBottomRight
- [x] Grouped Query Attention/Multi Query Attention
Use the following scripts for testing.

Run inference examples:
```bash
#!/bin/bash
TESTS="example_batched_gemm_scale_softmax_gemm_xdl_fp16 \
       example_batched_gemm_scale_softmax_gemm_xdl_bf16 \
       example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 \
       example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 \
       example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 \
       example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 \
       example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16"

FAILED_TESTS=""
for test in $TESTS; do
    ./bin/$test
    if test $? -eq 0; then
        echo "$test succeeded!"
    else
        FAILED_TESTS="$FAILED_TESTS $test"
    fi
done

for test in $FAILED_TESTS; do
    echo "$test failed"
done
```
Run training examples:

```bash
#!/bin/bash
TESTS="example_grouped_multihead_attention_forward_v1 \
       example_batched_multihead_attention_forward_v1 \
       example_grouped_multihead_attention_backward_v1 \
       example_batched_multihead_attention_backward_v1 \
       example_grouped_multihead_attention_train_v1 \
       example_batched_multihead_attention_train_v1 \
       example_grouped_multihead_attention_forward_v2 \
       example_batched_multihead_attention_forward_v2 \
       example_grouped_multihead_attention_backward_v2 \
       example_batched_multihead_attention_backward_v2 \
       example_grouped_multihead_attention_train_v2 \
       example_batched_multihead_attention_train_v2 \
       example_grouped_multihead_attention_backward_v3 \
       example_batched_multihead_attention_backward_v3"

FAILED_TESTS=""
for test in $TESTS; do
    ./bin/$test
    if test $? -eq 0; then
        echo "$test succeeded!"
    else
        FAILED_TESTS="$FAILED_TESTS $test"
    fi
done

for test in $FAILED_TESTS; do
    echo "$test failed"
done
```
Because the qloop fwd and bwd kernels use different layouts, we refactored dropout to decouple fwd from bwd. However, the refactored dropout added significant overhead to fwd, for three main reasons:

1. The dropout shuffle is performed in 4x4 units, which wastes the `get_random_8x16` instruction: each call generates eight random numbers per thread, but only the first four are used.
2. The transpose shuffle of the random-number matrix through LDS causes bank conflicts.
3. The fwd headdim-128 kernel has some scratch (register spilling) both before and after the dropout refactor.
I have not found a good solution for point 2 yet. For point 1, I optimized dropout to use 16x16 units so that the `get_random_8x16` instruction is fully utilized; for example, in the fwd kernel random numbers are exchanged between lanes 0-15 and lanes 32-47 through LDS. For point 3, when `Gemm1NPerBlock` is set to the head dimension (128), `TotalNumVgprs` is already close to 250 even without dropout, so it is difficult to optimize dropout enough to eliminate scratch; in this case I tried halving `Gemm1NPerBlock` to 64, which doubles the number of workgroups but eliminates the scratch. Note that when the LSE and the dropout random matrix are saved, they should not be written out repeatedly.
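To make point 1 concrete, below is a minimal host-side sketch, not CK's actual kernel code: `get_random_8x16_standin`, the chunk sizes, and the threshold handling are simplified assumptions. It illustrates why a dropout unit that consumes eight elements per generator call uses everything an eight-numbers-per-call generator produces, while a four-element unit discards half of it.

```cpp
#include <array>
#include <cstdint>

// Stand-in for a per-thread "eight 16-bit random numbers per call" generator.
// CK uses its own philox implementation; this splitmix64-style mixer exists
// only so the sketch compiles and runs.
std::array<uint16_t, 8> get_random_8x16_standin(uint64_t seed, uint64_t offset)
{
    std::array<uint16_t, 8> out{};
    uint64_t x = seed + offset * 0x9E3779B97F4A7C15ULL;
    for(int i = 0; i < 2; ++i)
    {
        x += 0x9E3779B97F4A7C15ULL;
        uint64_t z = x;
        z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
        z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
        z ^= z >> 31;
        for(int j = 0; j < 4; ++j)
            out[i * 4 + j] = static_cast<uint16_t>(z >> (16 * j));
    }
    return out;
}

// Inverted dropout over a per-thread chunk of N elements using ONE generator
// call. N = 4 mimics the old 4x4 shuffle unit: only rng[0..3] are consumed and
// the other four values are thrown away. N = 8 mimics the 16x16 unit, where
// every value of the call is consumed.
template <int N>
void dropout_chunk(float* vals, float p_drop, uint64_t seed, uint64_t offset)
{
    static_assert(N == 4 || N == 8, "illustrative chunk sizes only");
    const std::array<uint16_t, 8> rng = get_random_8x16_standin(seed, offset);
    const uint16_t threshold = static_cast<uint16_t>(p_drop * 65535.0f);
    const float scale = 1.0f / (1.0f - p_drop); // rescale kept elements
    for(int i = 0; i < N; ++i)                  // rng[N..7] are never read when N == 4
        vals[i] = (rng[i] <= threshold) ? 0.0f : vals[i] * scale;
}

int main()
{
    float tile_a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float tile_b[8] = {1, 2, 3, 4, 5, 6, 7, 8};

    // Old scheme: covering 8 elements needs two calls, i.e. 16 generated
    // values for only 8 dropout decisions.
    dropout_chunk<4>(tile_a, 0.2f, /*seed=*/42, /*offset=*/0);
    dropout_chunk<4>(tile_a + 4, 0.2f, 42, 1);

    // New scheme: one call covers the same 8 elements, so nothing is wasted.
    dropout_chunk<8>(tile_b, 0.2f, 42, 2);
    return 0;
}
```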
The following is a comparison before and after the optimization. I expect that later tuning the kernels with dropout using the CK profiler will give even better performance.
Dtype | Device Train Microbatch Size | Num_heads | Head_dim | Seqlen | Causal | Dropout | Fwd Before | Bwd Before | Fwd After | Bwd After | Improvement |
---|---|---|---|---|---|---|---|---|---|---|---|
fp16 | 16 | 12 | 64 | 1K | FALSE | 0.2 | 1.39495 ms, 36.9472 TFlops | 3.60971 ms, 35.6952 Tflops | 1.09508 ms, 47.0647 Tflops | 3.09402 ms, 41.6445 TFlops | 22.12% |
fp16 | 16 | 12 | 64 | 2K | FALSE | 0.2 | 5.38146 ms, 38.309 TFlops | 13.9454 ms, 36.958 Tflops | 4.22661 ms, 48.7764 Tflops | 11.8978 ms, 43.3185 Tflops | 22.36% |
fp16 | 8 | 16 | 64 | 1K | FALSE | 0.2 | 0.941686 ms, 36.4875 Tflops | 2.41688 ms, 35.5414 Tflops | 0.741349 ms, 46.3476 Tflops | 2.078 ms, 41.3375 Tflops | 21.74% |
fp16 | 8 | 16 | 64 | 2K | FALSE | 0.2 | 3.60951 ms, 38.0769 Tflops | 9.31209 ms, 36.898 Tflops | 2.8365 ms, 48.4537 Tflops | 7.93545 ms, 43.2991 Tflops | 22.38% |
fp16 | 8 | 16 | 64 | 4K | FALSE | 0.2 | 14.1361 ms, 38.8901 Tflops | 36.5511 ms, 37.6019 Tflops | 11.2258 ms, 48.9723 TFlops | 31.1255 ms, 44.1564 Tflops | 21.75% |
fp16 | 4 | 16 | 128 | 2K | FALSE | 0.2 | 4.00785 ms, 34.2924 Tflops | 8.70397 ms, 39.4759 Tflops | 3.3851 ms, 40.6012 Tflops | 8.44611 ms, 40.6812 Tflops | 10.19% |
fp16 | 8 | 20 | 128 | 2K | FALSE | 0.2 | 9.8889 ms, 34.7458 Tflops | 21.7228 ms, 39.5434 Tflops | 8.37536 ms, 41.0248 Tflops | 21.0789 ms, 40.7513 Tflops | 10.08% |
fp16 | 8 | 32 | 128 | 2K | FALSE | 0.2 | 15.9766 ms, 34.4101 Tflops | 34.726 ms, 39.5781 Tflops | 13.31 ms, 41.304 Tflops | 33.6967 ms, 40.787 Tflops | 10.95% |
fp16 | 8 | 40 | 128 | 2K | FALSE | 0.2 | 19.9986 ms, 34.3622 Tflops | 43.441 ms, 39.5476 Tflops | 16.6432 ms, 41.2898 Tflops | 42.1168 ms, 40.791 Tflops | 11.06% |
Is it possible to merge this with the current develop branch and resolve all the conflicts? I'd like to use the MHA kernels in the MIOpen library, but this branch is far too old.