composable_kernel
Wmma support for grouped convolution bwd weight
Proposed changes
Summary:
- Modify gridwise implementation to work with convolution (grid descriptors are not created internally but passed from the device level)
- Add device level implementation: `DeviceGroupedConvBwdWeight_Wmma_CShuffleV3`, `DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3` and `DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`
- Add device implementation of batched gemm multiple Ds (needed for explicit gemm - conv bwd weight)
- Adapt existing device implementation of explicit gemm to work for both xdl and wmma implementations of batched gemm multiple Ds
- Add support for occupancy-based splitk for one stage and two stage implementations of grouped conv bwd weight
- Create instances
- Add examples
- Remove old instances (they don't support splitk)
- Add tests for bwd weight scale
The implementations are based on CShuffleV3, but the functionality is the same as that of the xdl implementations.
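For readers unfamiliar with occupancy-based splitk, the general idea can be sketched as follows. This is only an illustration under stated assumptions, not the code added in this PR: `pick_k_batch`, its parameters, and the cap of 128 are hypothetical names and values chosen for the example. The sketch assumes the split factor is chosen so that the number of launched workgroups roughly matches the number of workgroups the device can keep resident at once.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper (not part of the composable_kernel API).
// Chooses a split-K factor so that output_tiles * k_batch is close to the
// number of workgroups the device can hold at the same time.
constexpr std::int64_t pick_k_batch(std::int64_t num_cu,           // compute units on the device
                                    std::int64_t occupancy_per_cu, // resident workgroups per CU
                                    std::int64_t output_tiles,     // workgroups needed without split-K
                                    std::int64_t k_loops,          // K-loop iterations available to split
                                    std::int64_t max_k_batch = 128) // assumed upper bound
{
    // Fall back to "no split" on degenerate input.
    if(num_cu <= 0 || occupancy_per_cu <= 0 || output_tiles <= 0 || k_loops <= 0 ||
       max_k_batch <= 0)
    {
        return 1;
    }

    // Workgroups that can be resident on the whole device at once.
    const std::int64_t resident_slots = num_cu * occupancy_per_cu;

    // How many K-slices per output tile are needed to fill those slots
    // (integer division rounds down, so the device is never oversubscribed).
    const std::int64_t fill = resident_slots / output_tiles;

    // Never split finer than the K loop allows, and never below 1.
    return std::clamp<std::int64_t>(fill, 1, std::min(k_loops, max_k_batch));
}
```

With these assumptions, a small problem that needs few output tiles gets a larger split factor, while a problem that already fills the device falls back to a factor of 1. A real implementation would typically obtain the occupancy value from the runtime for the specific kernel rather than pass it in by hand.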
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
- [x] I have added inline documentation which helps the maintainers understand the motivation
- [x] I have removed the stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
- [x] I have run `clang-format` on all changed files
- [x] Any dependent changes have been merged
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered