
[CK Tile][Critical][Performance] Slow CK Tile GEMM compared to universal_gemm in the old CK

Open zjing14 opened this issue 1 year ago • 3 comments

I tried the CK Tile GEMM with the V3 pipeline (https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm.cpp) for compute-bound cases (such as M = 4096, N = 4096 and K = 4096), but got much worse performance than the old CK example (https://github.com/ROCm/composable_kernel/blob/develop/example/01_gemm/gemm_xdl_bf16_v3.cpp) with the same tile size, 256x256x64.

CK Tile V3 (359.483 TFlops)

./bin/tile_example_universal_gemm -m=4096 -n=4096 -k=4096 -v=0
Launching kernel with args: grid: {16, 16, 1}, blocks: {256, 1, 1}
Run Gemm kernel with M =4096 N =4096 K =4096 StrideA =4096 StrideB =4096 StrideC =4096 : 0.382324 ms, 359.483 TFlops, 263.293 GB/s,

vs. Old CK GEMM V3 (615.46 TFlops)

./bin/example_gemm_xdl_bf16_v3 0 2 1 4096 4096 4096 4096 4096 4096 1
a_m_k: dim 2, lengths {4096, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
RotatingMemWrapper: { size_a: 33554432, size_b: 33554432, rotating_count: 4}
Perf: 0.223311 ms, 615.46 TFlops, 450.776 GB/s, DeviceGemmXdlUniversal<Default, RCR> BlkSize: 256, BlkTile: 256x256x64, WaveTile: 32x32, WaveMap: 4x4, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
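
As a sanity check on the two logs, the reported TFlops and GB/s figures can be reproduced from the kernel times alone. The sketch below is not taken from either example's source; the helper names are hypothetical, and it assumes the usual conventions of 2*M*N*K floating-point operations per GEMM and one read of A, one read of B, and one write of C at 2 bytes per element.

```cpp
#include <cstdint>

// Hypothetical helpers, not part of CK or CK Tile.

// FLOP count of a GEMM: one multiply and one add per (m, n, k) triple.
constexpr double gemm_flops(std::int64_t M, std::int64_t N, std::int64_t K) {
    return 2.0 * static_cast<double>(M) * static_cast<double>(N) * static_cast<double>(K);
}

// Bytes moved, assuming A (MxK) and B (KxN) are read once and C (MxN) is written once.
constexpr double gemm_bytes(std::int64_t M, std::int64_t N, std::int64_t K,
                            std::int64_t elem_bytes) {
    return static_cast<double>(elem_bytes) * static_cast<double>(M * K + K * N + M * N);
}

// Throughput in TFlops for a kernel time given in milliseconds.
constexpr double tflops(std::int64_t M, std::int64_t N, std::int64_t K, double ms) {
    return gemm_flops(M, N, K) / (ms * 1.0e9);
}

// Effective bandwidth in GB/s for a kernel time given in milliseconds.
constexpr double gbps(std::int64_t M, std::int64_t N, std::int64_t K,
                      std::int64_t elem_bytes, double ms) {
    return gemm_bytes(M, N, K, elem_bytes) / (ms * 1.0e6);
}
```

With M = N = K = 4096 and 2-byte elements, tflops(4096, 4096, 4096, 0.382324) comes out at about 359.48 and tflops(4096, 4096, 4096, 0.223311) at about 615.46, matching both logs, so the two runs count work the same way. On these numbers CK Tile reaches roughly 58% of the old CK throughput.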

zjing14, Dec 06 '24 06:12

@carlushuang @aosewski

zjing14, Dec 06 '24 06:12

Hi @zjing14. An internal ticket has been created to investigate your issue. Thanks!

ppanchad-amd, Dec 06 '24 15:12

@zjing14 On develop, we are currently at about 80% of the performance of this instance:

ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
    PassThrough, PassThrough, PassThrough, GemmDefault,
    256,
    256, 256,
    64, 8, 8,
    32, 32,
    4, 4,
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 8, 8, 0,
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 8, 8, 0,
    1, 1, S<1, 32, 1, 8>, 8,
  ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;

You just need to enlarge K_Tile to 64: constexpr ck_tile::index_t K_Tile = 64;
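
To see why the K tile size matters, it helps to count how many main-loop iterations each workgroup runs over the K dimension. The sketch below is only an illustration: the helper name is hypothetical, and because the thread does not state the previous K_Tile value, 32 is used purely as an assumed smaller value for comparison.

```cpp
#include <cstdint>

// Hypothetical helper, not part of CK Tile: number of main-loop iterations a
// workgroup needs to cover the K dimension with a given K tile size
// (ceiling division, so a partial last tile still counts as one iteration).
constexpr std::int64_t k_loop_iterations(std::int64_t K, std::int64_t k_tile) {
    return (K + k_tile - 1) / k_tile;
}

// K = 4096 with K_Tile = 64, the value suggested in this comment.
static_assert(k_loop_iterations(4096, 64) == 64, "64 iterations at K_Tile = 64");

// K = 4096 with an assumed smaller K_Tile of 32 (hypothetical value).
static_assert(k_loop_iterations(4096, 32) == 128, "128 iterations at K_Tile = 32");
```

Halving the K tile doubles the number of loop iterations for the same problem, so each iteration has less compute available to hide global-memory and LDS traffic. That is one plausible reason the larger tile closes part of the gap, though the thread itself does not give a breakdown.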

aosewski, Jan 31 '25 12:01