
[WebGPU] Support `dot4I8Packed(int8x4, int8x4)` as a pure extern method

Status: Open. Jiawei-Shao opened this issue 1 year ago (0 comments).

This patch adds support for `dot4I8Packed(int8x4, int8x4)` as a pure extern method for the WebGPU target. In the generated WGSL shader, `int8x4` is translated into `u32`, and `dot4I8Packed(int8x4, int8x4)` is translated into the WGSL built-in function `dot4I8Packed(u32, u32)`, which returns the `i32` dot product of the four packed signed 8-bit lanes.
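To make the lowering concrete, here is a minimal pure-Python sketch of what the WGSL built-in computes once each `int8x4` value has become a `u32`. The helper names `pack_int8x4`, `unpack_int8x4` and `dot4_i8_packed` are hypothetical, and the byte layout (lane 0 in the least significant byte) is an assumption; the dot product itself does not depend on lane order as long as both operands use the same one.

```python
def pack_int8x4(lanes):
    """Pack four signed 8-bit lanes into one u32 (assumed: lane 0 = lowest byte)."""
    if len(lanes) != 4:
        raise ValueError("expected exactly 4 lanes")
    word = 0
    for i, v in enumerate(lanes):
        if not -128 <= v <= 127:
            raise ValueError(f"lane {i} out of int8 range: {v}")
        # Store the two's-complement byte of each lane at bits 8*i .. 8*i+7.
        word |= (v & 0xFF) << (8 * i)
    return word


def unpack_int8x4(word):
    """Split a u32 into four sign-extended int8 lanes (lane 0 = lowest byte)."""
    lanes = []
    for i in range(4):
        byte = (word >> (8 * i)) & 0xFF
        # Sign-extend: bytes >= 0x80 represent negative int8 values.
        lanes.append(byte - 256 if byte >= 128 else byte)
    return lanes


def dot4_i8_packed(a, b):
    """Reference semantics of dot4I8Packed(u32, u32) -> i32."""
    # Each lane is widened before multiplying, so the worst case
    # 4 * (-128 * -128) = 65536 fits comfortably in an i32.
    return sum(x * y for x, y in zip(unpack_int8x4(a), unpack_int8x4(b)))


a = pack_int8x4([1, -2, 3, -4])
b = pack_int8x4([5, 6, -7, 8])
print(hex(a), dot4_i8_packed(a, b))
```

Here `1*5 + (-2)*6 + 3*(-7) + (-4)*8` gives `-60`, which is the value the kernel below would write to `C[i]` for that pair of elements.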

Here is an example of using `dot4I8Packed` with the WebGPU target:

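```python
import tvm
from tvm import te

# Assumed target definition (not in the original snippet); adjust `host` for your build.
tgt = tvm.target.Target("webgpu", host="llvm")
n = te.var("n")
A = te.placeholder((n,), "int8x4", name="A")
B = te.placeholder((n,), "int8x4", name="B")
# C[i] is the int32 dot product of the four int8 lanes of A[i] and B[i].
C = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("int32", "dot4I8Packed", A[i], B[i]), name="C")
s = te.create_schedule(C.op)
# Split the loop into blocks of 64 threads and bind them to the GPU grid.
bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
mod = tvm.build(s, [A, B, C], tgt, name="dp4aTest")
```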
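To check the kernel's output, a host-side reference is handy. The sketch below assumes the `int8x4` buffers are filled from NumPy `int8` arrays of shape `(n, 4)` (one row per `int8x4` element); the function name `dp4a_reference` is hypothetical. Running `mod` itself needs a WebGPU-capable runtime, so that part is left out.

```python
import numpy as np


def dp4a_reference(a, b):
    """Expected output of the `C` kernel: per-row int8 dot product as int32.

    `a` and `b` are int8 arrays of shape (n, 4). Lanes are widened to int32
    before multiplying, mirroring what dot4I8Packed does on the GPU.
    """
    assert a.dtype == np.int8 and b.dtype == np.int8
    assert a.shape == b.shape and a.shape[1] == 4
    return (a.astype(np.int32) * b.astype(np.int32)).sum(axis=1, dtype=np.int32)


# Random test inputs covering the full int8 range [-128, 127].
rng = np.random.default_rng(0)
n = 1024
a_np = rng.integers(-128, 128, size=(n, 4), dtype=np.int8)
b_np = rng.integers(-128, 128, size=(n, 4), dtype=np.int8)
c_ref = dp4a_reference(a_np, b_np)
```

Widening before the multiply matters: computing `a_np * b_np` directly in `int8` would silently wrap around, while the WGSL built-in returns the exact `i32` sum.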

Issue: #16627

Jiawei-Shao, May 08 '24 07:05