Tile engine for streamk

Open CongMa13 opened this issue 1 month ago • 0 comments

Proposed changes

We copied the GEMM tile engine code and revised it to align with StreamK.

Note: The entire tile engine will be refactored to extract common functionality.

  • Use the new StreamK implementation as the cpp template in .
  • Add reduction_strategy to default_config.json so that instances are generated for both the atomic and reduction strategies.
  • Add persistent==true to default_config.json.

The resulting benchmark instances are:

... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x32
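The list above looks like a Cartesian product over a few parameters, so it can be reproduced compactly. The following is a minimal sketch, not the tile engine's actual generator: the parameter values are read off the instance names, while the function name `instance_names`, the constant names, and the interpretation of the fourth boolean flag as `persistent` are assumptions made for illustration.

```python
from itertools import product

# Values read off the instance names listed above. The field meanings
# (for example, that the fourth boolean flag is `persistent`) are assumptions.
WARP_LAYOUTS = ["1x4x1", "2x2x1", "4x1x1"]
WARP_TILES = {
    "fp16": ["16x16x16", "16x16x32", "32x32x16", "32x32x8", "4x64x16"],
    "fp8": ["16x16x32", "32x32x16", "32x32x32"],
}
PERSISTENT = [False, True]
STRATEGIES = ["atomic", "reduction"]


def instance_names():
    """Return the benchmark names in the same order as the list above."""
    names = []
    for dtype, warp_tiles in WARP_TILES.items():
        for persistent, strategy, warps, warp_tile in product(
            PERSISTENT, STRATEGIES, WARP_LAYOUTS, warp_tiles
        ):
            names.append(
                "benchmark_gemm_streamk_"
                f"{dtype}_rcr_compv3_cshuffle_intrawave_"
                f"False_False_False_{persistent}_{strategy}_"
                f"256x256x32_{warps}_{warp_tile}"
            )
    return names


if __name__ == "__main__":
    names = instance_names()
    print(len(names))
    print(names[0])
    print(names[-1])
```

Under these assumptions, fp16 contributes 2 x 2 x 3 x 5 = 60 names and fp8 contributes 2 x 2 x 3 x 3 = 36, which matches the 96 instances listed.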

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • [ ] I have added inline documentation which enables the maintainers with understanding the motivation
  • [ ] I have removed the stale documentation which is no longer relevant after this pull request
  • [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • [x] I have run clang-format on all changed files
  • [x] Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

CongMa13, Nov 04 '25 19:11