
Sequence length 1 GEMV alternative for fused attention

Open · cloudhan opened this issue 2 years ago · 1 comment

Sequence length 1 is extremely important for decoding (ASR, text generation, etc.).

In onnxruntime, we found that rocblas gemm + softmax kernel + rocblas gemm is much faster for this case:

> KERNEL_EXPLORER_BUILD_DIR=./ python ../../onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py float16 2 1 1500 6 64 0 0 --scale 0.125
 27.26 us  0.17 tflops float16 B=2 S=1 T=1500 N=6 H=64, Generic   # <------------- this is rocblas gemm + softmax kernel + rocblas gemm
187.71 us  0.02 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 64, 32, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 256, 128, 32, 8, 8, 256, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 256, 128, 32, 8, 8, 256, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 256, 32, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 32, 8, 8, 64, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 32, 8, 8, 64, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 64, 8, 8, 64, 128, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
not supported          float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 64, 256, 64, 8, 8, 64, 64, 32, 2, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
185.33 us  0.03 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
185.19 us  0.03 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 128, 64, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
187.60 us  0.02 tflops float16 B=2 S=1 T=1500 N=6 H=64, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<256, 128, 64, 32, 8, 8, 128, 128, 32, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
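
For context on what the unfused path computes, here is a minimal numpy sketch of the gemm + softmax + gemm + permute sequence for the B=2, S=1, T=1500, N=6, H=64 case above (a numerical reference only, not the rocblas/onnxruntime code; the 0.125 scale is the test's --scale argument, i.e. 1/sqrt(H)):

```python
import numpy as np

B, S, T, N, H = 2, 1, 1500, 6, 64          # batch, query len, kv len, heads, head dim
scale = 0.125                               # 1 / sqrt(H)

rng = np.random.default_rng(0)
q = rng.standard_normal((B, N, S, H), dtype=np.float32)
k = rng.standard_normal((B, N, T, H), dtype=np.float32)
v = rng.standard_normal((B, N, T, H), dtype=np.float32)

# GEMM 1: with S == 1 this is really a batched GEMV, (1 x H) @ (H x T)
scores = scale * (q @ k.transpose(0, 1, 3, 2))             # (B, N, S, T)

# softmax over the kv length
scores -= scores.max(axis=-1, keepdims=True)
probs = np.exp(scores)
probs /= probs.sum(axis=-1, keepdims=True)

# GEMM 2: again a batched GEMV for S == 1, (1 x T) @ (T x H)
out = probs @ v                                            # (B, N, S, H)

# permute: (B, N, S, H) -> (B, S, N, H) as in the c_gs_ms_os layout
out = out.transpose(0, 2, 1, 3)
print(out.shape)                                           # (2, 1, 6, 64)
```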

The shapes for the run above are as follows:

a_gs_ms_ks: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
b0_gs_ns_ks: dim, 4, lengths {2, 6, 1500, 64}, strides {576000, 96000, 64, 1}
b1_gs_os_ns: dim, 4, lengths {2, 6, 64, 1500}, strides {576000, 96000, 1, 64}
c_gs_ms_os: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
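
Decoding those descriptors (my reading, with hypothetical variable names): G0=B=2, G1=N=6, M=S=1, N=T=1500, K=O=H=64, so the a/c strides match a packed BSNH buffer and the b0/b1 strides match a packed BNTH buffer. A quick numpy check of that interpretation:

```python
import numpy as np

B, S, T, N, H = 2, 1, 1500, 6, 64

# Q / output: packed BSNH, viewed by the kernel as (G0=B, G1=N, M=S, K=H)
q_bsnh = np.empty((B, S, N, H), dtype=np.float16)
q_view = q_bsnh.transpose(0, 2, 1, 3)
assert tuple(s // q_view.itemsize for s in q_view.strides) == (384, 64, 384, 1)

# K: packed BNTH, viewed as (G0=B, G1=N, N=T, K=H) -- already in kernel order
k_bnth = np.empty((B, N, T, H), dtype=np.float16)
assert tuple(s // k_bnth.itemsize for s in k_bnth.strides) == (576000, 96000, 64, 1)

# V: same BNTH packing as K, viewed as (G0=B, G1=N, O=H, N=T)
v_bnth = np.empty((B, N, T, H), dtype=np.float16)
v_view = v_bnth.transpose(0, 1, 3, 2)
assert tuple(s // v_view.itemsize for s in v_view.strides) == (576000, 96000, 1, 64)
```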

Another case is:

a_gs_ms_ks: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
b0_gs_ns_ks: dim, 4, lengths {2, 6, 21, 64}, strides {49152, 8192, 64, 1}  <---- 21 will increase during decoding 
b1_gs_os_ns: dim, 4, lengths {2, 6, 64, 21}, strides {49152, 8192, 1, 64}  <---- 21 will increase during decoding 
c_gs_ms_os: dim, 4, lengths {2, 6, 1, 64}, strides {384, 64, 384, 1}
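
A toy decode loop (my own illustration, not onnxruntime code) showing where the growth comes from: each generated token appends one key/value row to the cache, so T increases by 1 per step while S stays 1:

```python
import numpy as np

B, N, H = 2, 6, 64
t_prompt = 21                                    # T at the step captured above
k_cache = np.random.rand(B, N, t_prompt, H).astype(np.float16)

for step in range(3):                            # a few decode steps
    k_new = np.random.rand(B, N, 1, H).astype(np.float16)  # this token's key
    k_cache = np.concatenate([k_cache, k_new], axis=2)     # T -> T + 1
    print(k_cache.shape[2])                                # 22, 23, 24
```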

It seems the current fused attention pads the matrices and calls into the tensor cores in any case, wasting compute for small sequence lengths. We might need a DeviceBatchedGemvSoftmaxGemvPermute variant for this case.
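
As a back-of-envelope estimate (based on my reading of the instance names above as 256 = BlockSize, 128 = MPerBlock, not a profiled number), padding M = S = 1 up to a 128-row M tile leaves less than 1% of the tile's math on real data:

```python
M, M_tile = 1, 128        # actual query rows vs. assumed MPerBlock of the instance
print(M / M_tile)         # ~0.0078 -> >99% of the tile's MFMA work is on padding
```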

cloudhan · Jun 30 '23 11:06

@cloudhan Apologies for the lack of response. Can you please check if this is still an issue with the latest ROCm 6.2? If not, please close the ticket. Thanks!

ppanchad-amd · Aug 20 '24 20:08