composable_kernel
Missing non-bias and bias-plus-mask versions of the client API for flash attention
The current client API ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute only has the version with exactly one D0 built into the client library. The non-bias version has no built-in instances. Please add instances to the library.
We also need a 2-D0 version that takes both attention bias and mask. The problem here is that the mask in our code is of type int32_t, which makes the number of combinations grow exponentially. I can copy the instance code and specialize it for our 2-D0 case, but I think the ck maintainers should take the maintenance burden into account. How should we keep track of changes to those instances as ck is upgraded?
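To make the growth concrete, here is a minimal sketch that counts the D0 dtype assignments a prebuilt library would need to cover. The dtype list is an assumption for illustration (only int32_t for the mask comes from our code), and the helper names are hypothetical, not part of ck.

```python
from itertools import product

# Hypothetical element types a D0 tensor (bias or mask) might use; illustrative only.
D0_DTYPES = ["f16", "bf16", "int32"]


def d0_combinations(num_d0s, dtypes=D0_DTYPES):
    """Return every ordered assignment of a dtype to each of num_d0s D0 tensors."""
    return list(product(dtypes, repeat=num_d0s))


def total_variants(max_d0s, dtypes=D0_DTYPES):
    """Count variants across 0..max_d0s D0 tensors (0 is the non-bias case)."""
    return sum(len(dtypes) ** n for n in range(max_d0s + 1))


if __name__ == "__main__":
    for n in range(3):
        print(f"{n} D0 tensor(s): {len(d0_combinations(n))} dtype assignment(s)")
    print("total for 0..2 D0 tensors:", total_variants(2))
```

Each of these assignments would additionally be multiplied by the number of gemm configurations instantiated per variant, so the real instance count is larger still.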
@cloudhan Actually, for the non-bias/1 D0/2 D0 versions, we just need to change NumDimO in our instances, so we can add all combinations easily. However, doing so would significantly increase compilation time. One way to reduce the maintenance burden is to use code-gen to create instances based on customer needs. We are currently investigating a proper way to provide a Python API.
https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/59cbb20c7c4e4ccb297a818685c74885ee853206/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp#L72
https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/59cbb20c7c4e4ccb297a818685c74885ee853206/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp#L104
I know I can always copy that code and specialize it for our purpose. I am just wondering whether, before copying a load of code into the ort repo, there is a better way, or an officially recommended way, for this scenario. Since the low-level API has not settled down, it may change in the future.
Since ORT uses a pre-built CK, all instance combinations have to be generated statically. A probable solution is CK with code-gen, which can generate instances on the fly.
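As a rough illustration of the code-gen idea, the sketch below emits one stub source per requested combination instead of enumerating every combination up front. Everything here is hypothetical: the naming scheme, the request format, and the emitted text are assumptions for illustration and do not reflect any existing or planned CK Python API.

```python
def instance_name(base_dtype, d0_dtypes):
    """Build a hypothetical instance name from the base dtype and D0 dtypes."""
    suffix = "_".join(d0_dtypes) if d0_dtypes else "nobias"
    return f"attn_{base_dtype}_d0_{suffix}"


def generate_sources(requests):
    """Map each requested (base_dtype, d0_dtypes) pair to a stub source string.

    Only the combinations a client asks for are emitted, so the number of
    generated files tracks actual needs rather than the full cross product.
    """
    sources = {}
    for base_dtype, d0_dtypes in requests:
        name = instance_name(base_dtype, tuple(d0_dtypes))
        d0_list = ", ".join(d0_dtypes)
        sources[name + ".cpp"] = (
            "// generated stub (hypothetical layout, not real CK code)\n"
            f"// base dtype: {base_dtype}; D0 dtypes: [{d0_list}]\n"
        )
    return sources


if __name__ == "__main__":
    # A client that needs only the non-bias case and the bias + int32 mask case.
    wanted = [("f16", []), ("f16", ["f16", "int32"])]
    for filename, text in generate_sources(wanted).items():
        print(filename)
        print(text)
```

In a real generator the stub body would be replaced by the instance definitions, with the gemm configuration list supplied by the CK side and the D0 types supplied by the client.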
@zjing14 I chatted with @ltqin on Teams. In the original issue, I'd like either:
- ck provides enough client API instances, or
- ck makes the following code customizable, except for the gemm configuration part. https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/1e59eb3be5a970ad3996d52984919a9edd7ce58f/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp#L35-L64
I prefer the second option because it allows the ck team to maintain the gemm configuration while we maintain the customized instances. We cannot foresee future requirements: there might be bf16+fp16 bias, or an int8 mask with other op combinations, and the number of combinations is exponential. Putting all of these in the client API as instances is impractical.
In short, we opt for the second option.
Created PR626.
@cloudhan Is this ticket still relevant? If not, please close the ticket. Thanks!