[WebGPU] Support warp-level shuffle primitives with subgroup
Overview
This PR adds support for warp-level shuffle primitives using the newly introduced subgroups feature in WebGPU, and uses them in the allreduce lowering.
The introduced primitives are:
- `subgroupShuffle()`
- `subgroupShuffleUp()`
- `subgroupShuffleDown()`
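The following minimal Python simulation shows how these primitives can compose into an allreduce. It assumes a power-of-two subgroup size and models each invocation as one list element. The helper names (`subgroup_shuffle`, `subgroup_shuffle_up`, `subgroup_shuffle_down`, `subgroup_allreduce_sum`) are hypothetical stand-ins for the WGSL built-ins. Out-of-range shuffles return the lane's own value, which is a modeling choice because WGSL does not guarantee a meaningful result in that case. The reduction is a common shuffle-down tree followed by a broadcast from lane 0. It is a sketch of the technique and not the exact code TVM emits.

```python
from typing import List

# Hypothetical simulation of one WebGPU subgroup: element i of each list is the
# value held by the invocation whose subgroup_invocation_id is i.
# Modeling choice: when a shuffle reads outside the subgroup, the lane keeps its
# own value (WGSL does not guarantee a meaningful result in that case).


def subgroup_shuffle(values: List[float], src_ids: List[int]) -> List[float]:
    """Models subgroupShuffle(e, id): lane i reads the value held by lane src_ids[i]."""
    return [values[src] for src in src_ids]


def subgroup_shuffle_up(values: List[float], delta: int) -> List[float]:
    """Models subgroupShuffleUp(e, delta): lane i reads lane i - delta."""
    return [values[i - delta] if i - delta >= 0 else values[i]
            for i in range(len(values))]


def subgroup_shuffle_down(values: List[float], delta: int) -> List[float]:
    """Models subgroupShuffleDown(e, delta): lane i reads lane i + delta."""
    n = len(values)
    return [values[i + delta] if i + delta < n else values[i] for i in range(n)]


def subgroup_allreduce_sum(values: List[float]) -> List[float]:
    """Sums across the subgroup so that every lane ends up with the total."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "power-of-two subgroup size assumed"
    acc = list(values)
    offset = n // 2
    # Tree reduction: after the step with a given offset, lanes below that
    # offset hold partial sums; lane 0 holds the full sum at the end.
    while offset > 0:
        shifted = subgroup_shuffle_down(acc, offset)
        acc = [a + b for a, b in zip(acc, shifted)]
        offset //= 2
    # Broadcast lane 0's result to every lane.
    return subgroup_shuffle(acc, [0] * n)


if __name__ == "__main__":
    print(subgroup_allreduce_sum([1.0, 2.0, 3.0, 4.0]))  # [10.0, 10.0, 10.0, 10.0]
```

For a subgroup of size n, this takes log2(n) shuffle-down steps plus one broadcast, and values are exchanged directly between invocations rather than through workgroup (shared) memory.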
This PR largely follows the Metal counterpart:
- https://github.com/apache/tvm/pull/15401
Tested end-to-end with WebLLM on Llama3.2-1B-q4f16_1 and Llama3.1-8B-q4f16_1. The dumped WebGPU kernels contain the subgroup shuffle primitives: https://gist.github.com/CharlieFRuan/cb54a8db0513ecbbc16c5de8df5ab845
Remaining TODOs
- [ ] Benchmark speedup
- [ ] Make subgroup usage configurable when targeting WebGPU, since not all devices support subgroups
- [ ] Check whether `GPUFeatureName` in `@webgpu/types` includes `subgroups`
- [ ] Some WebGPU devices support more than 256 threads per block; support targeting these different kinds of devices
Resources
- https://github.com/gpuweb/gpuweb/blob/859fdd4a803a11e7b8de70483aa75c365be18b0e/proposals/subgroups.md
- https://www.w3.org/TR/WGSL/#subgroup-builtin-functions
- https://developer.chrome.com/blog/new-in-webgpu-134?hl=en