[CK Tile][Critical][Performance] Slow CK Tile GEMM compared to universal_gemm in the old CK
I tried the CK Tile GEMM with the V3 pipeline (https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm.cpp) on a compute-bound case (M = 4096, N = 4096, K = 4096), but it performs much worse than the old CK example (https://github.com/ROCm/composable_kernel/blob/develop/example/01_gemm/gemm_xdl_bf16_v3.cpp) with the same 256x256x64 tile size: 359.483 vs. 615.46 TFlops, i.e. about 58% of the old CK throughput.
CK Tile V3 (359.483 TFlops)
./bin/tile_example_universal_gemm -m=4096 -n=4096 -k=4096 -v=0
Launching kernel with args: grid: {16, 16, 1}, blocks: {256, 1, 1}
Run Gemm kernel with M =4096 N =4096 K =4096 StrideA =4096 StrideB =4096 StrideC =4096 : 0.382324 ms, 359.483 TFlops, 263.293 GB/s,
vs. Old CK GEMM V3 (615.46 TFlops)
./bin/example_gemm_xdl_bf16_v3 0 2 1 4096 4096 4096 4096 4096 4096 1
a_m_k: dim 2, lengths {4096, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
RotatingMemWrapper: { size_a: 33554432, size_b: 33554432, rotating_count: 4}
Perf: 0.223311 ms, 615.46 TFlops, 450.776 GB/s, DeviceGemmXdlUniversal<Default, RCR> BlkSize: 256, BlkTile: 256x256x64, WaveTile: 32x32, WaveMap: 4x4, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
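For anyone who wants to sanity-check the figures in the two logs above, here is a minimal sketch of how the TFlops and GB/s values can be recomputed from the reported kernel time. It assumes the usual 2*M*N*K flop count and that A, B and C are each moved once with 2-byte (bf16) elements; `GemmThroughput` and `gemm_throughput` are hypothetical helper names, not part of CK or CK Tile.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper type: throughput derived from one timed GEMM run.
struct GemmThroughput
{
    double tflops;   // 1e12 floating-point operations per second
    double gb_per_s; // 1e9 bytes per second
};

// Recompute throughput from problem size and measured kernel time.
// Assumptions (not taken from CK source): 2*M*N*K flops per GEMM,
// A, B and C each transferred exactly once, elem_bytes = 2 for bf16.
inline GemmThroughput gemm_throughput(std::int64_t M,
                                      std::int64_t N,
                                      std::int64_t K,
                                      double time_ms,
                                      int elem_bytes = 2)
{
    assert(time_ms > 0.0);

    const double m = static_cast<double>(M);
    const double n = static_cast<double>(N);
    const double k = static_cast<double>(K);

    const double flops   = 2.0 * m * n * k;
    const double bytes   = elem_bytes * (m * k + k * n + m * n);
    const double seconds = time_ms * 1e-3;

    return {flops / seconds / 1e12, bytes / seconds / 1e9};
}
```

Calling `gemm_throughput(4096, 4096, 4096, 0.382324)` and `gemm_throughput(4096, 4096, 4096, 0.223311)` should land close to the numbers printed by the two examples, which suggests both binaries use the same accounting and the gap is purely in kernel time.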
@carlushuang @aosewski
Hi @zjing14. An internal ticket has been created to investigate this issue. Thanks!
@zjing14 On develop, CK Tile currently reaches about 80% of the performance of this old CK instance:
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
    PassThrough, PassThrough, PassThrough, GemmDefault,
    256,                                  // BlockSize
    256, 256,                             // MPerBlock, NPerBlock
    64, 8, 8,                             // KPerBlock, AK1, BK1
    32, 32,                               // MPerXDL, NPerXDL (wave tile 32x32)
    4, 4,                                 // MXdlPerWave, NXdlPerWave (wave map 4x4)
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,  // A block transfer
    2, 8, 8, 0,
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,  // B block transfer
    2, 8, 8, 0,
    1, 1, S<1, 32, 1, 8>, 8,              // C shuffle
    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
You just need to increase K_Tile to 64 in the CK Tile example:
constexpr ck_tile::index_t K_Tile = 64;
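To illustrate why the K tile matters here, below is a rough sketch of how the block tile sizes map to the launch grid and to the number of K-loop steps each workgroup runs. It assumes one workgroup per M_Tile x N_Tile output tile and a plain K loop with no split-K, which is an assumption about the example rather than something confirmed from the CK Tile source; `ceil_div`, `TileCounts` and `tile_counts` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: integer ceiling division for positive operands.
constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b)
{
    return (a + b - 1) / b;
}

// Hypothetical summary of how a GEMM problem is cut into block tiles.
struct TileCounts
{
    std::int64_t m_blocks; // output tiles along M
    std::int64_t n_blocks; // output tiles along N
    std::int64_t k_iters;  // K-loop steps per workgroup (assumes no split-K)
};

inline TileCounts tile_counts(std::int64_t M,
                              std::int64_t N,
                              std::int64_t K,
                              std::int64_t m_tile,
                              std::int64_t n_tile,
                              std::int64_t k_tile)
{
    assert(m_tile > 0 && n_tile > 0 && k_tile > 0);
    return {ceil_div(M, m_tile), ceil_div(N, n_tile), ceil_div(K, k_tile)};
}
```

For the 4096x4096x4096 problem with a 256x256x64 tile this gives 16 x 16 output tiles, which matches the `grid: {16, 16, 1}` line in the CK Tile log, and 64 K-loop steps per workgroup. Under the same assumption, a smaller K tile leaves the grid unchanged but increases the number of K-loop steps proportionally.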