composable_kernel
Sequence length 1 GEMV alternative for fused attention
Sequence length 1 is extremely important for decoding (ASR, text generation, etc.).
In onnxruntime, we found that the rocBLAS GEMM + softmax kernel + rocBLAS GEMM path is much faster for this case:
> KERNEL_EXPLORER_BUILD_DIR=./ python ../../onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py float16 2 1 1500 6 64 0 0 --scale 0.125
27.26 us 0.17 tflops float16 B=2 S=1 T=1500 N=6 H=64, Generic # <------------- this is rocblas gemm + softmax kernel + rocblas gemm
187.71 us 0.02 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 64, 32, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 256, 128, 32, 8, 8, 256, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 256, 128, 32, 8, 8, 256, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 256, 32, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 32, 8, 8, 64, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 32, 8, 8, 64, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 64, 8, 8, 64, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 64, 8, 8, 64, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
185.33 us 0.03 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
185.19 us 0.03 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
187.60 us 0.02 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 64, 32, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
The shapes for the case above are as follows:
a_gs_ms_ks: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
b0_gs_ns_ks: dim, 4, lengths {2, 6, 1500, 64}, strides {576000, 96000, 64, 1}
b1_gs_os_ns: dim, 4, lengths {2, 6, 64, 1500}, strides {576000, 96000, 1, 64}
c_gs_ms_os: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
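For reference, a minimal NumPy sketch of what this case computes (shapes, strides, and the 0.125 scale are taken from above; how I map the strided views onto Q/K/V is my own reading and may differ from onnxruntime's actual layout):

```python
import numpy as np

B, N, S, T, H = 2, 6, 1, 1500, 64   # batch, heads, query length, key/value length, head dim
scale = 0.125                        # from the command line above

q = np.random.randn(B, N, S, H).astype(np.float16)   # a_gs_ms_ks
k = np.random.randn(B, N, T, H).astype(np.float16)   # b0_gs_ns_ks
v = np.random.randn(B, N, T, H).astype(np.float16)   # b1_gs_os_ns is this buffer viewed as (B, N, H, T) via strides

# First GEMM: with S == 1 each (b, n) slice is really a GEMV, (1 x H) @ (H x T).
logits = (q.astype(np.float32) @ k.astype(np.float32).transpose(0, 1, 3, 2)) * scale   # (B, N, S, T)

# Row-wise softmax over the key/value length T.
logits -= logits.max(axis=-1, keepdims=True)
probs = np.exp(logits)
probs /= probs.sum(axis=-1, keepdims=True)

# Second GEMM: again a GEMV per (b, n), (1 x T) @ (T x H).
out = probs @ v.astype(np.float32)                    # (B, N, S, H)

# Permute to (B, S, N, H) to match the c_gs_ms_os strides {384, 64, 384, 1}.
out = out.transpose(0, 2, 1, 3).astype(np.float16)
print(out.shape)   # (2, 1, 6, 64)
```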
Another case is:
a_gs_ms_ks: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
b0_gs_ns_ks: dim, 4, lengths {2, 6, 21, 64}, strides {49152, 8192, 64, 1} <---- 21 will increase during decoding
b1_gs_os_ns: dim, 4, lengths {2, 6, 64, 21}, strides {49152, 8192, 1, 64} <---- 21 will increase during decoding
c_gs_ms_os: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
It seems the current fused attention kernels pad the matrices and call into the XDL matrix cores in every case, wasting compute for small sequence lengths. We might need a DeviceBatchedGemvSoftmaxGemvPermute variant for this case.
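To make the wasted-compute claim concrete, a back-of-the-envelope check against the log above (my arithmetic; the MPerBlock = 128 tile size is read from the instance names):

```python
# FLOPs actually needed for the first case: Q*K^T and P*V are each 2*S*T*H FLOPs per (batch, head).
B, N, S, T, H = 2, 6, 1, 1500, 64
flops = B * N * 2 * (2 * S * T * H)      # 4,608,000 FLOPs

print(flops / 27.26e-6 / 1e12)           # ~0.17 TFLOPS for rocblas gemm + softmax + rocblas gemm
print(flops / 187.71e-6 / 1e12)          # ~0.02 TFLOPS for the fused Xdl instance

# The supported fused instances use MPerBlock = 128, so M = S = 1 is padded up to 128 and
# roughly 127/128 of the MFMA work along the M dimension is spent on padding.
```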
@cloudhan Apologies for the lack of response. Could you please check whether this is still an issue with the latest ROCm 6.2? If not, please close the ticket. Thanks!