
[WebGPU] Support warp-level shuffle primitives with subgroup

Open CharlieFRuan opened this issue 8 months ago • 0 comments

Overview

This PR adds support for warp-level shuffle primitives in the WebGPU backend, built on WebGPU's newly introduced subgroup feature. The primitives are then used in the allreduce lowering.

The introduced primitives are:

  • subgroupShuffle()
  • subgroupShuffleUp()
  • subgroupShuffleDown()
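
To illustrate how shuffle primitives can drive an allreduce, here is a minimal CPU-side sketch in TypeScript. It is not the lowering code from this PR: the array stands in for the per-invocation registers of one subgroup, and `shuffleDown`, `shuffle`, and `allReduceSum` are hypothetical helper names. It assumes a power-of-two subgroup size and assumes that lanes reading out of range keep their own value (WGSL leaves that result indeterminate).

```typescript
// Each array slot stands for the register of one invocation in a subgroup.
// Stand-in for subgroupShuffleDown(v, delta): lane i reads lane i + delta.
// Assumption: out-of-range lanes keep their own value.
function shuffleDown(values: number[], delta: number): number[] {
  return values.map((v, lane) => {
    const src = values[lane + delta];
    return src === undefined ? v : src;
  });
}

// Stand-in for subgroupShuffle(v, srcLane): every lane reads from srcLane.
function shuffle(values: number[], srcLane: number): number[] {
  const src = values[srcLane];
  if (src === undefined) {
    throw new Error("source lane out of range");
  }
  return values.map(() => src);
}

// Tree-style sum allreduce: halve the shuffle distance each step so that
// lane 0 ends up with the full sum, then broadcast lane 0 to all lanes.
function allReduceSum(values: number[]): number[] {
  const n = values.length;
  if (n === 0 || (n & (n - 1)) !== 0) {
    throw new Error("subgroup size must be a power of two");
  }
  let acc = values.slice();
  for (let delta = n >> 1; delta > 0; delta >>= 1) {
    const shifted = shuffleDown(acc, delta);
    acc = acc.map((v, lane) => v + (shifted[lane] as number));
  }
  return shuffle(acc, 0);
}
```

For example, `allReduceSum([1, 2, 3, 4])` leaves every lane holding 10. Only lane 0 is guaranteed to hold the full sum after the shuffle-down loop, which is why the final broadcast step is needed.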

This PR largely follows the Metal counterpart:

  • https://github.com/apache/tvm/pull/15401

Tested end-to-end in WebLLM with Llama3.2-1B-q4f16_1 and Llama3.1-8B-q4f16_1. The dumped WebGPU kernel does contain the subgroup shuffle primitives: https://gist.github.com/CharlieFRuan/cb54a8db0513ecbbc16c5de8df5ab845

Remaining TODOs

  • [ ] Benchmark speedup
  • [ ] Make subgroup usage configurable when targeting WebGPU, since not all devices support subgroups
  • [ ] Check whether GPUFeatureName in @webgpu/types includes subgroups
  • [ ] Some WebGPU devices support more than 256 threads per block; allow targeting different limits
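
As a rough sketch of what the runtime side of the device-dependent TODOs could look like, the helper below derives kernel options from an adapter's features and limits. It is not part of this PR: `pickKernelConfig` and `KernelConfig` are hypothetical names, and the two small interfaces are structural stand-ins for `GPUAdapter.features` and `GPUAdapter.limits` so the snippet compiles without @webgpu/types. The `"subgroups"` feature name and the `maxComputeInvocationsPerWorkgroup` limit come from the WebGPU API.

```typescript
// Structural stand-ins for GPUAdapter.features and GPUAdapter.limits.
interface FeatureSetLike {
  has(name: string): boolean;
}
interface LimitsLike {
  maxComputeInvocationsPerWorkgroup: number;
}

// Hypothetical options a compile step could be parameterized on.
interface KernelConfig {
  useSubgroups: boolean;
  maxThreadsPerBlock: number;
}

// Decide whether shuffle-based kernels can be used and how many threads
// per block the device allows.
function pickKernelConfig(
  features: FeatureSetLike,
  limits: LimitsLike
): KernelConfig {
  return {
    useSubgroups: features.has("subgroups"),
    maxThreadsPerBlock: limits.maxComputeInvocationsPerWorkgroup,
  };
}
```

In a browser this would be called as `pickKernelConfig(adapter.features, adapter.limits)`. Note that `"subgroups"` also has to be listed in `requiredFeatures` when requesting the device, and a limit above the default of 256 has to be requested through `requiredLimits`.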

Resources

  • https://github.com/gpuweb/gpuweb/blob/859fdd4a803a11e7b8de70483aa75c365be18b0e/proposals/subgroups.md
  • https://www.w3.org/TR/WGSL/#subgroup-builtin-functions
  • https://developer.chrome.com/blog/new-in-webgpu-134?hl=en

CharlieFRuan, Mar 03 '25 03:03