
[Issue]: Very slow perf for Gemm BF16

Open ghostplant opened this issue 6 months ago • 9 comments

To reproduce:

Command: ./bin/ckProfiler gemm 2 1 1 2 0 1 32 512 7168 -1 -1 -1 3 100

GPU Type: MI300x

Searched Perf: Best Perf for datatype = bf16 ALayout = RowMajor BLayout = ColumnMajor M = 32 N = 512 K = 7168 StrideA = 7168 StrideB = 7168 StrideC = 512 : 0.0634091 ms, 3.70421 TFlops, 123.508 GB/s, DeviceGemm_Xdl_CShuffle<Default, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, 8, 8, 1, 1> LoopScheduler: Default, PipelineVersion: v2

ghostplant avatar Jun 02 '25 00:06 ghostplant

With rocBLAS, the same case takes 0.009 ms.
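The reported figures can be cross-checked by hand. The sketch below assumes the profiler counts 2·M·N·K floating-point operations and a single pass over A, B, and C at 2 bytes per bf16 element. Under those assumptions it reproduces the 3.70421 TFlops and 123.508 GB/s quoted above to within rounding of the printed time, and applies the same accounting to the 0.009 ms rocBLAS time. The helper name `gemm_metrics` is illustrative and not part of CK.

```python
def gemm_metrics(m, n, k, time_ms, bytes_per_elem=2):
    """Return (TFlops, GB/s) for one M x N x K GEMM that took time_ms."""
    flops = 2 * m * n * k                               # one multiply and one add per MAC
    traffic = bytes_per_elem * (m * k + k * n + m * n)  # read A and B once, write C once
    seconds = time_ms * 1e-3
    return flops / seconds / 1e12, traffic / seconds / 1e9

# Times taken from the two comments above.
ck_tflops, ck_gbps = gemm_metrics(32, 512, 7168, 0.0634091)
rb_tflops, rb_gbps = gemm_metrics(32, 512, 7168, 0.009)

print(f"CK:      {ck_tflops:.4f} TFlops, {ck_gbps:.2f} GB/s")
print(f"rocBLAS: {rb_tflops:.4f} TFlops, {rb_gbps:.2f} GB/s")
print(f"time ratio: {0.0634091 / 0.009:.2f}x")
```

With the same operation count, the two timings differ by roughly a factor of 7.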

ghostplant avatar Jun 02 '25 00:06 ghostplant

Hi @ghostplant. An internal ticket has been created to investigate this issue. Thanks!

ppanchad-amd avatar Jun 02 '25 14:06 ppanchad-amd

Hi @ghostplant, thanks for reporting this. I reproduced the numbers and saw a similar gap between ckProfiler and rocblas-bench for this GEMM shape. From what I can see, the discrepancy is expected for several reasons:

  • ckProfiler is primarily optimized for throughput. CK kernels are generally tuned for large GEMMs and training workloads. For small M/N cases like M=32, CK may not select the best kernel out of the box.

  • rocBLAS, on the other hand, includes specialized low-latency kernels for small GEMMs and performs less runtime tuning, leading to much faster execution in these scenarios.

Additionally, CK may insert layout transformations or padding internally depending on the kernel, which also impacts timing.

If your use case is latency-sensitive (e.g., inference), rocBLAS is a better fit for small GEMMs. CK excels in high-throughput, large-batch training scenarios.
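One way to see why a small-M shape can underuse the GPU is to count workgroups. The sketch below rests on three assumptions that are not confirmed in this thread: that the instance string in the original report lists BlockSize, MPerBlock, NPerBlock, KPerBlock in that order (so the selected kernel uses a 32 x 64 output tile), that each output tile maps to one workgroup when split-K is not used, and that the workgroups spread across the 304 compute units of an MI300X. The helper `workgroup_count` is hypothetical and not a CK API.

```python
import math

MI300X_COMPUTE_UNITS = 304  # published CU count for MI300X

def workgroup_count(m, n, m_per_block, n_per_block, k_batch=1):
    """Workgroups launched if every output tile (times k_batch) gets one."""
    tiles_m = math.ceil(m / m_per_block)
    tiles_n = math.ceil(n / n_per_block)
    return tiles_m * tiles_n * k_batch

# Assumed reading of DeviceGemm_Xdl_CShuffle<Default, 64, 32, 64, 32, ...>:
# BlockSize=64, MPerBlock=32, NPerBlock=64, KPerBlock=32.
groups = workgroup_count(32, 512, m_per_block=32, n_per_block=64)
share = groups / MI300X_COMPUTE_UNITS
print(f"{groups} workgroups for {MI300X_COMPUTE_UNITS} CUs ({share:.1%} of CUs at best)")
```

Under these assumptions only a handful of compute units receive work, and each of them iterates over the full K = 7168 reduction.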

adityas-amd avatar Jul 11 '25 17:07 adityas-amd

@adityas-amd For large GEMMs, ckProfiler is also slower than rocBLAS.

We need ckProfiler mainly for custom GEMM fusion/quantization that rocBLAS cannot do, regardless of whether the GEMM is large or small. Can you provide instructions on how to add more kernel selection choices to improve ckProfiler's perf on the common shape above? Thanks.

ghostplant avatar Jul 11 '25 17:07 ghostplant

To add more kernel choices to improve ckProfiler performance for the shape M=32, N=512, K=7168 (bf16, RowMajor A, ColumnMajor B):

The instances for this layout and datatype are defined in library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp:

https://github.com/ROCm/composable_kernel/blob/e6104daecc7e29d26fc0435dd697132bdd262163/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp

The following function in that file registers all supported kernel variants for this layout and datatype:

void add_device_gemm_xdl_cshuffle_bf16_bf16_bf16_mk_nk_mn_instances( std::vector<std::unique_ptr<DeviceGemmPtr>>& instances)

Inside the function above, you can add or duplicate kernel instances of this form:

using DeviceOpInstance = DeviceGemmXdlCShuffle<BlockSize, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, PipelineVer>

Refer to the existing instances in the file for syntax and parameter patterns. You may tune the following parameters:

  • BlockSize
  • MPerBlock, NPerBlock, KPerBlock
  • MPerXdl, NPerXdl
  • ABlockTransferThreadSliceLengths, etc.

After modifying or adding new entries, rebuild using cmake -S . -B build && cmake --build build -j and re-run ckProfiler. The profiler will now include your customized kernels.

adityas-amd avatar Aug 01 '25 16:08 adityas-amd

Thanks. However, I just recompiled the latest CK and the build failed, so I cannot try your suggested fix. I am not sure whether some recent commits broke compilation for non-gfx950 GPUs:

/root/composable_kernel/include/ck_tile/core/numeric/pk_fp4.hpp:241:40: error: constexpr function never produces a constant expression [-Winvalid-constexpr]
  241 | CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
      |                                        ^~~~~~~~~~~~~~~~
/root/composable_kernel/include/ck_tile/core/numeric/pk_fp4.hpp:246:61: note: subexpression not valid in a constant expression
  246 |     return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
      |                                                             ^~~~
/root/composable_kernel/include/ck_tile/core/numeric/pk_fp4.hpp:250:40: error: constexpr function never produces a constant expression [-Winvalid-constexpr]
  250 | CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
      |                                        ^~~~~~~~~~~~~~~~
/root/composable_kernel/include/ck_tile/core/numeric/pk_fp4.hpp:255:61: note: subexpression not valid in a constant expression
  255 |     return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
      |                                                             ^~~~
/root/composable_kernel/include/ck_tile/core/numeric/pk_fp4.hpp:259:40: error: constexpr function never produces a constant expression [-Winvalid-constexpr]
  259 | CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
      |                                        ^~~~~~~~~~~~~~~~
/root/composable_kernel/include/ck_tile/core/numeric/pk_fp4.hpp:264:41: note: subexpression not valid in a constant expression
  264 |     return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));

ghostplant avatar Aug 04 '25 11:08 ghostplant

Hello @ghostplant, let me try to reproduce this again. I did not have this issue earlier. Are you seeing this on the main develop branch?

adityas-amd avatar Aug 14 '25 16:08 adityas-amd

Hello @ghostplant, try universal GEMM with splitK:

./bin/ckProfiler gemm_universal 2 1 1 2 0 1 32 512 7168 -1 -1 -1 -1 3 100 0

In my measurement it is 15x faster (selected splitK=16). With such a splitK value you may see somewhat lower accuracy.
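One plausible reason split-K can cost accuracy is that each K-slice produces a partial sum that is rounded before the partials are combined. The sketch below illustrates that effect for a single dot product. It is an assumption about the mechanism, not a description of how CK actually reduces its partial results. `to_bf16` and `dot_split_k` are hypothetical helpers, and Python floats stand in for the fp32 accumulator.

```python
import random
import struct

def to_bf16(x):
    """Round a float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the discarded low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def dot_split_k(a, b, k_batch):
    """Dot product where each K-slice is rounded to bf16 and summed in bf16."""
    k = len(a)
    chunk = -(-k // k_batch)  # ceiling division: elements per K-slice
    total = 0.0
    for start in range(0, k, chunk):
        stop = start + chunk
        partial = sum(x * y for x, y in zip(a[start:stop], b[start:stop]))
        total = to_bf16(total + to_bf16(partial))  # assumed bf16 accumulation of partials
    return total

random.seed(0)
K = 7168
a = [to_bf16(random.uniform(-1, 1)) for _ in range(K)]
b = [to_bf16(random.uniform(-1, 1)) for _ in range(K)]
reference = to_bf16(sum(x * y for x, y in zip(a, b)))  # single rounding at the end

for k_batch in (1, 16, 32):
    result = dot_split_k(a, b, k_batch)
    print(f"k_batch={k_batch:2d}: result={result:+.6f} |error|={abs(result - reference):.6f}")
```

With `k_batch=1` the result is rounded only once and matches the reference by construction. Larger values add one rounding step per slice, so the error may grow, although it need not grow monotonically for any particular input.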

For higher accuracy you can also disable splitK or decrease its value. Example with splitK = 1:

./bin/ckProfiler gemm_universal 2 1 1 2 0 1 32 512 7168 -1 -1 -1 1 3 100 0

In that case it is 3x faster.
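The two commands quoted in this comment differ in exactly one positional argument, which indicates where the splitK value goes. A small sketch to locate it follows. It only compares the two command lines shown here and does not consult ckProfiler's own argument help.

```python
auto_cmd = "./bin/ckProfiler gemm_universal 2 1 1 2 0 1 32 512 7168 -1 -1 -1 -1 3 100 0"
fixed_cmd = "./bin/ckProfiler gemm_universal 2 1 1 2 0 1 32 512 7168 -1 -1 -1 1 3 100 0"

# Drop the binary path and the operation name, keep the positional arguments.
auto_args = auto_cmd.split()[2:]
fixed_args = fixed_cmd.split()[2:]

# Collect (1-based position, value in first command, value in second command).
differing = [
    (position, old, new)
    for position, (old, new) in enumerate(zip(auto_args, fixed_args), start=1)
    if old != new
]
print(differing)  # prints [(13, '-1', '1')]
```

The 13th argument after the operation name is -1 in the first command and 1 in the splitK = 1 command. Reading -1 as "let the profiler choose the splitK value" is an inference from these two commands, not a documented guarantee.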

bartekxk avatar Aug 26 '25 21:08 bartekxk

Hi @bartekxk, may I know which argument value in the command below stands for splitK=16?

./bin/ckProfiler gemm_universal 2 1 1 2 0 1 32 512 7168 -1 -1 -1 -1 3 100 0

If I run the command directly, I get a lot of "Incorrect results" errors:

Perf: 0.00627527 ms, 37.4296 TFlops, 1248 GB/s, DeviceGemmXdlUniversal<KPadding, RCR> BlkSize: 128, BlkTile: 16x128x64, WaveTile: 16x16, WaveMap: 1x4, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2, BlkGemmPipelinePrefetchStages: 2, Kpack: 8, KBatch 19
Error: Incorrect results!        out[20] != ref[20]: -0.1875 != -0.2597656
Error: Incorrect results!        out[39] != ref[39]: -0.359375 != -0.4628906
Error: Incorrect results!        out[41] != ref[41]: -1.046875 != -1.1875
Error: Incorrect results!        out[82] != ref[82]: -1.484375 != -1.65625
max err: 0.3046875, number of errors: 397, 2.423096% wrong values
Perf: 0.00593449 ms, 39.579 TFlops, 1319.67 GB/s, DeviceGemmXdlUniversal<KPadding, RCR> BlkSize: 128, BlkTile: 16x128x64, WaveTile: 16x16, WaveMap: 1x4, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v2, BlkGemmPipelinePrefetchStages: 2, Kpack: 8, KBatch 32
Error: Incorrect results!        out[20] != ref[20]: -0.390625 != -0.2597656
Error: Incorrect results!        out[39] != ref[39]: -0.28125 != -0.4628906
Error: Incorrect results!        out[153] != ref[153]: 0.1523438 != 0.1113281
Error: Incorrect results!        out[166] != ref[166]: -0.9882812 != -0.890625
max err: 0.3828125, number of errors: 419, 2.557373% wrong values
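To put the mismatches in context, the sketch below computes the absolute and relative gap for the first four error lines quoted above, and checks the assumption that the reported percentage is the error count divided by the M x N = 32 x 512 output elements. The regular expression is written for the log format shown here and is not a ckProfiler utility.

```python
import re

log = """\
Error: Incorrect results!        out[20] != ref[20]: -0.1875 != -0.2597656
Error: Incorrect results!        out[39] != ref[39]: -0.359375 != -0.4628906
Error: Incorrect results!        out[41] != ref[41]: -1.046875 != -1.1875
Error: Incorrect results!        out[82] != ref[82]: -1.484375 != -1.65625
"""

# Capture the element index, the computed value, and the reference value.
pattern = re.compile(r"out\[(\d+)\] != ref\[\d+\]: (\S+) != (\S+)")
mismatches = [(int(i), float(out), float(ref)) for i, out, ref in pattern.findall(log)]

for index, out, ref in mismatches:
    abs_err = abs(out - ref)
    print(f"out[{index}]: abs err {abs_err:.4f}, rel err {abs_err / abs(ref):.1%}")

outputs = 32 * 512  # assumed denominator: M * N elements of C
for errors in (397, 419):
    print(f"{errors} / {outputs} = {errors / outputs:.6%}")
```

The computed percentages (2.423096% and 2.557373%) match the log, which supports the element-count assumption. Whether gaps of this size are acceptable depends on the tolerance the profiler applies to bf16 results.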

ghostplant avatar Sep 13 '25 04:09 ghostplant