composable_kernel

[Question] Register data layout for two consecutive GEMMs in a flash attention kernel (how is TransposedC implemented)?

bulffi opened this issue 1 year ago • 1 comment

Hi CK team, thanks for your effort in creating such a cool library! 🙏 I have a dumb question about the flash attention kernel at the AMD GPU ISA level.

As we know, the flash attention kernel consists of gemm_0, some reduction operations, and then gemm_1. The result of gemm_0 is held in registers as a static_distributed_tensor, is never written back to LDS during the reduction operations, and is finally used as the first input (the A matrix) of gemm_1. However, looking at the register data layout of __builtin_amdgcn_mfma_f32_32x32x8f16, the instruction takes small 'rows' of the A matrix and small 'cols' of the B matrix and produces small 'cols' of the C matrix. This means the layout of C is effectively transposed into a column-major fashion.
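The layout described above can be checked with a small host-side model. This is a minimal Python sketch, assuming the wave64 register layout of a single 32x32x8 block of v_mfma_f32_32x32x8f16 as we read it from the CDNA ISA documentation (please verify against the ISA doc for your target); a_coord, b_coord and c_coord are hypothetical helper names that map (lane, register item) to a matrix coordinate.

```python
# Hypothetical host-side model of the wave64 register layout of
# v_mfma_f32_32x32x8f16 (__builtin_amdgcn_mfma_f32_32x32x8f16).
# Assumed layout for one 32x32x8 block (64 lanes):
#   A (32x8):  lane l, item j -> A[l % 32][4 * (l // 32) + j],  j in 0..3
#   B (8x32):  lane l, item j -> B[4 * (l // 32) + j][l % 32],  j in 0..3
#   C (32x32): lane l, item v -> C[8*(v//4) + 4*(l//32) + v%4][l % 32], v in 0..15

def a_coord(lane, j):
    """(row, k) of the A element held in `lane`, register item j."""
    return lane % 32, 4 * (lane // 32) + j

def b_coord(lane, j):
    """(k, col) of the B element held in `lane`, register item j."""
    return 4 * (lane // 32) + j, lane % 32

def c_coord(lane, v):
    """(row, col) of the C element held in `lane`, accumulator item v."""
    return 8 * (v // 4) + 4 * (lane // 32) + v % 4, lane % 32

def elements_of_lane(coord_fn, lane, items):
    """All matrix coordinates owned by one lane."""
    return [coord_fn(lane, i) for i in range(items)]

if __name__ == "__main__":
    # A: a lane holds a short run of one *row* (4 consecutive k values).
    print("A lane 0:", elements_of_lane(a_coord, 0, 4))
    # B: a lane holds a short run of one *column*.
    print("B lane 0:", elements_of_lane(b_coord, 0, 4))
    # C: a lane holds 16 values, all from the same *column* (col 0 here).
    print("C lane 0:", elements_of_lane(c_coord, 0, 16))
```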

My thinking is that the output layout of gemm_0 has to match the layout of the first input of gemm_1 for this algorithm to work efficiently. This matters not only for gemm_1: reductions like row max and row sum are also row-based operations, which seems to conflict with the column-major output layout of __builtin_amdgcn_mfma_f32_32x32x8f16.
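To see why the default C layout is awkward for row max and row sum, the sketch below (same assumed MFMA layout, hypothetical helper names) counts how many lanes own a piece of a given row versus a given column of the accumulator.

```python
# Hypothetical model: default (non-transposed) C layout of one wave64
# v_mfma_f32_32x32x8f16 block, as assumed from the CDNA ISA docs.
def c_coord(lane, v):
    """(row, col) of the C element in `lane`, accumulator item v (0..15)."""
    return 8 * (v // 4) + 4 * (lane // 32) + v % 4, lane % 32

def lanes_holding_row(row):
    """Lanes that own at least one element of C[row][:]."""
    return sorted({lane for lane in range(64) for v in range(16)
                   if c_coord(lane, v)[0] == row})

def lanes_holding_col(col):
    """Lanes that own at least one element of C[:][col]."""
    return sorted({lane for lane in range(64) for v in range(16)
                   if c_coord(lane, v)[1] == col})

print(len(lanes_holding_row(0)))  # 32: a row reduction spans 32 lanes
print(lanes_holding_col(0))       # [0, 32]: a column lives in just two lanes
```

Under this assumed layout, one row of S is spread across 32 lanes, so a row max or row sum would need a wide cross-lane reduction, while a column sits in only two lanes.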

For comparison, CUDA (CuTe) has an MMA atom called SM80_16x8x16_F32F16F16F32_TN, wrapping the PTX mma.sync m16n8k16 instruction, which has the property I described above: the output of the GEMM has the same layout as the first input of that GEMM.

I tried to find related operations in the ck_tile implementation of the flash attention kernel. I did notice that both gemm_0 and gemm_1 are annotated with a type called WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution, which seems to deal with the fact that the C matrix has to be transposed to be useful.
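One possible reading of a "transposed C" distribution, sketched under assumptions: if the accumulator is laid out with rows and columns swapped, each lane owns a slice of a single row of S, and that slice can be fed straight into the A operand of the next MFMA. The model below assumes gemm_1 is issued as 32x32x8 MFMAs stepping K by 8 and ignores the f32 to f16 conversion of P; the real WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution may use a different K per instruction and extra swizzling, and all function names here are hypothetical.

```python
# Hypothetical model of a "transposed C" accumulator layout for a wave64
# 32x32 MFMA tile: row and column roles are swapped relative to the default
# layout, so lane l owns 16 values of row l % 32.
def transposed_c_coord(lane, v):
    """(row, col) of S/P owned by `lane`, accumulator item v (0..15)."""
    return lane % 32, 8 * (v // 4) + 4 * (lane // 32) + v % 4

def a_coord(lane, j):
    """(row, k) expected in the A operand of one 32x32x8 MFMA (assumed layout)."""
    return lane % 32, 4 * (lane // 32) + j

def feeds_next_gemm_directly():
    """Check that items 4*kb .. 4*kb+3 are exactly the A operand for K-slice kb."""
    for lane in range(64):
        for kb in range(4):  # 32 columns of P consumed as 4 slices of K = 8
            for j in range(4):
                row, col = transposed_c_coord(lane, 4 * kb + j)
                a_row, a_k = a_coord(lane, j)
                if (row, col) != (a_row, 8 * kb + a_k):
                    return False
    return True

def lanes_holding_row(row):
    """Lanes that own at least one element of row `row`."""
    return sorted({lane for lane in range(64) for v in range(16)
                   if transposed_c_coord(lane, v)[0] == row})

print(feeds_next_gemm_directly())  # True: no shuffle needed before gemm_1
print(lanes_holding_row(0))        # [0, 32]: in-lane reduce + one 2-lane exchange
```

If this reading is right, no explicit cross-lane transpose is needed: row reductions are mostly in-lane, and the accumulator registers are already in A-operand order for gemm_1.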

However, my understanding of Composable Kernel is still very limited, and I could not find any cross-lane operations that actually implement this transpose. I did notice that transposed_vectors and warp_shuffle exist, but I feel they do not directly address the problem of transposing C.

Could you please point me to where exactly this transpose happens? Or is there a gap in my understanding, and the transpose is not necessary, so that you can finish the flash attention operation without ever transposing the S (and later P) matrix?

Thanks for your help in advance!

bulffi, Sep 15 '24 19:09

@bulffi An internal ticket has been created to assist with your question. Thanks!

ppanchad-amd, Sep 25 '24 18:09

Hi @bulffi, the transpose is done simply by swapping the A and B operands of the MFMA instruction, which results in a transposed layout of the C matrix (the instruction then computes B^T * A^T = (A * B)^T). You can refer to the ISA documents for CDNA1, CDNA2 and CDNA3 on the https://gpuopen.com/ website. Alternatively, there is an example for playing with the layout here: https://github.com/carlushuang/gcnasm/tree/master/matrix_core; please refer to the comments in that code.

carlushuang, Oct 09 '24 11:10
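The operand swap described in the answer can be checked numerically. This is a minimal Python emulation, not CK code: mfma_32x32x8 is a hypothetical stand-in for the instruction that reads 64 lanes of registers using the assumed A/B layouts of v_mfma_f32_32x32x8f16 and writes C in the assumed accumulator layout; integers stand in for f16/f32 so results compare exactly.

```python
import random

# Assumed wave64 layouts of v_mfma_f32_32x32x8f16 (single 32x32x8 block).
def a_coord(lane, j):  # A is 32x8: lane owns 4 consecutive k of one row
    return lane % 32, 4 * (lane // 32) + j

def b_coord(lane, j):  # B is 8x32: lane owns 4 consecutive k of one column
    return 4 * (lane // 32) + j, lane % 32

def c_coord(lane, v):  # C is 32x32: lane owns 16 values of one column
    return 8 * (v // 4) + 4 * (lane // 32) + v % 4, lane % 32

def pack(mat, coord_fn, items):
    """Distribute a matrix into 64 per-lane register lists."""
    return [[mat[r][c] for r, c in (coord_fn(lane, i) for i in range(items))]
            for lane in range(64)]

def matmul(x, y):
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))] for i in range(len(x))]

def mfma_32x32x8(a_regs, b_regs):
    """Hypothetical emulation of one MFMA: decode operands, multiply, re-pack C."""
    a = [[0] * 8 for _ in range(32)]
    b = [[0] * 32 for _ in range(8)]
    for lane in range(64):
        for j in range(4):
            r, k = a_coord(lane, j)
            a[r][k] = a_regs[lane][j]
            k, c = b_coord(lane, j)
            b[k][c] = b_regs[lane][j]
    return pack(matmul(a, b), c_coord, 16)

random.seed(0)
A = [[random.randint(-3, 3) for _ in range(8)] for _ in range(32)]
B = [[random.randint(-3, 3) for _ in range(32)] for _ in range(8)]
C = matmul(A, B)

a_regs, b_regs = pack(A, a_coord, 4), pack(B, b_coord, 4)
normal = mfma_32x32x8(a_regs, b_regs)   # lane l: part of column l % 32 of C
swapped = mfma_32x32x8(b_regs, a_regs)  # same registers, operands swapped

for lane in range(64):
    for v in range(16):
        r, c = c_coord(lane, v)
        assert normal[lane][v] == C[r][c]   # standard accumulator layout
        assert swapped[lane][v] == C[c][r]  # C^T: lane l holds part of row l % 32
print("swapping A and B yields C^T in the accumulator layout")
```

Because the assumed A and B register layouts mirror each other, the swap costs no data movement; only the meaning of the result changes, and each lane ends up holding a slice of one row of C.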

Hi @bulffi, closing the issue as answered. If you have further questions, please let us know.

huanrwan-amd, Oct 17 '24 19:10