composable_kernel
Tile engine for streamk
Proposed changes
We copied the tile engine code for GEMM and revised it to align with StreamK.
Note: The entire tile engine will be refactored to extract common functionality.
- Use the new StreamK implementation as the cpp template in .
- Add `reduction_strategy` to default_config.json so that there are instances for `atomic` and `reduction`
- Add `persistent==true` to default_config.json
The resulting benchmark instances are:
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_4x64x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_16x16x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x8
... benchmark_gemm_streamk_fp16_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_4x64x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_atomic_256x256x32_4x1x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_False_reduction_256x256x32_4x1x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_atomic_256x256x32_4x1x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_1x4x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_2x2x1_32x32x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_16x16x32
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x16
... benchmark_gemm_streamk_fp8_rcr_compv3_cshuffle_intrawave_False_False_False_True_reduction_256x256x32_4x1x1_32x32x32
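The list above is the Cartesian product of a few parameter axes, so it can be rebuilt with a short script. The following is only a sketch for reading the names, not the tile engine's actual generator: the constant and function names are hypothetical, the axis values are copied from the list, and treating the fourth boolean as the `persistent` flag is an assumption (it is the only boolean that varies). The first three booleans are kept fixed at `False` because their meaning is not stated here.

```python
from itertools import product

# Warp-tile shapes per data type, copied from the instance list above.
WARP_TILES = {
    "fp16": ["16x16x16", "16x16x32", "32x32x16", "32x32x8", "4x64x16"],
    "fp8": ["16x16x32", "32x32x16", "32x32x32"],
}
WARPS = ["1x4x1", "2x2x1", "4x1x1"]
STRATEGIES = ["atomic", "reduction"]
# Assumption: the fourth boolean in each name is the persistent flag.
PERSISTENT = [False, True]


def instance_names():
    """Rebuild the benchmark target names in the order they are listed."""
    names = []
    for dtype, warp_tiles in WARP_TILES.items():
        for persistent, strategy, warp, warp_tile in product(
            PERSISTENT, STRATEGIES, WARPS, warp_tiles
        ):
            names.append(
                f"benchmark_gemm_streamk_{dtype}_rcr_compv3_cshuffle_intrawave_"
                f"False_False_False_{persistent}_{strategy}_256x256x32_"
                f"{warp}_{warp_tile}"
            )
    return names


names = instance_names()
print(len(names))  # 60 fp16 + 36 fp8 = 96
```

Each data type contributes 2 (persistent) x 2 (reduction strategy) x 3 (warp layouts) x its number of warp tiles, which matches the 60 fp16 and 36 fp8 entries listed.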
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
- [ ] I have added inline documentation that helps the maintainers understand the motivation
- [ ] I have removed the stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
- [x] I have run `clang-format` on all changed files
- [x] Any dependent changes have been merged
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered