
[CK_TILE] implement basic flatmm

Open feifei14119 opened this issue 10 months ago • 2 comments

Proposed changes

Add a basic flatmm implementation based on ck_tile:

  • flatmm is placed in a separate example folder
  • flatmm uses its own kernel, pipeline, and block functions
  • flatmm is designed to reuse the gemm warp functions

In this change, we only implement the basic flatmm functionality and framework:

  • support fp16/bf16 input data types
  • support a 1x4 warp shape with different tile sizes
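To illustrate what a 1x4 warp shape with different tile sizes can mean, here is a minimal sketch that splits a block tile across 1 warp along M and 4 warps along N. It is not the ck_tile code: the function names, the row-major warp ordering, and the example 128x128 tile are hypothetical, and it assumes 64-lane wavefronts as on AMD CDNA GPUs.

```python
# Hypothetical sketch: partitioning a block tile across a warps_m x warps_n layout.
# Not ck_tile's implementation; names and warp ordering are illustrative assumptions.
WAVEFRONT_SIZE = 64  # lanes per wavefront on AMD CDNA GPUs


def warp_tiles(block_m, block_n, warps_m=1, warps_n=4):
    """Return (per-warp M, per-warp N, threads per block) for the given warp layout."""
    if block_m % warps_m or block_n % warps_n:
        raise ValueError("block tile must divide evenly across the warp layout")
    threads = warps_m * warps_n * WAVEFRONT_SIZE
    return block_m // warps_m, block_n // warps_n, threads


def warp_origin(warp_id, block_m, block_n, warps_m=1, warps_n=4):
    """Top-left (row, col) of the sub-tile owned by warp_id, assuming row-major warp order."""
    wm, wn, _ = warp_tiles(block_m, block_n, warps_m, warps_n)
    return (warp_id // warps_n) * wm, (warp_id % warps_n) * wn


if __name__ == "__main__":
    # A hypothetical 128x128 block tile: each of the 4 warps owns a 128x32 slice of C.
    print(warp_tiles(128, 128))
    for w in range(4):
        print(w, warp_origin(w, 128, 128))
```

With a 1x4 layout every warp covers the full M extent of the block tile, so changing the tile size only changes how wide each warp's N slice is.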

Future features will be added in follow-up changes:

  • support fp8/bf8 data types
  • support block scale
  • support more tile sizes, warp shapes, and MFMA instructions

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • [x] I have added inline documentation which enables the maintainers with understanding the motivation
  • [ ] I have removed the stale documentation which is no longer relevant after this pull request
  • [x] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • [x] I have run clang-format on all changed files
  • [ ] Any dependent changes have been merged

Discussion

Not sure whether we need to support inline asm for ck_tile flatmm.

feifei14119 · Feb 08 '25 09:02

Could you add some smoke tests? @feifei14119

coderfeli · Mar 19 '25 05:03

162.674 TFlops on 308.

    ./bin/tile_example_flatmm_basic -m=5120 -n=5120 -k=8192
    Run Flatmm kernel with M =5120 N =5120 K =8192 StrideA =8192 StrideB =8192 StrideC =5120 : 2.64023 ms, 162.674 TFlops, 83.4022 GB/s,
    Relative error threshold: 0.000488281 Absolute error threshold: 0.344238
    The CPU veification result is:correct

coderfeli · Mar 19 '25 08:03
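The throughput figures in the benchmark output above can be reproduced from M, N, K, and the measured time with the usual GEMM formulas: 2·M·N·K floating-point operations, and bytes counted as reading A and B once and writing C once. The sketch below assumes 2 bytes per element (fp16/bf16 for A, B, and C); that assumption, and the helper name `gemm_throughput`, are illustrative and not taken from the example's source. The fact that these formulas match the reported numbers suggests, but does not confirm, that the example counts bandwidth this way.

```python
# Hypothetical helper: recompute TFLOPS and GB/s for C = A * B
# with A: m x k, B: k x n, C: m x n, assuming every matrix is touched exactly once.
def gemm_throughput(m, n, k, time_ms, bytes_per_elem=2):
    flops = 2 * m * n * k  # one multiply and one add per multiply-accumulate
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # read A, read B, write C
    seconds = time_ms * 1e-3
    tflops = flops / seconds / 1e12
    gbps = bytes_moved / seconds / 1e9  # decimal gigabytes
    return tflops, gbps


if __name__ == "__main__":
    tflops, gbps = gemm_throughput(5120, 5120, 8192, 2.64023)
    print(f"{tflops:.3f} TFlops, {gbps:.4f} GB/s")
```

For this problem size the arithmetic intensity is very high (about 430 GFLOP over about 220 MB), so the low GB/s figure is expected and the kernel is compute-bound.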