
[Tracking Issue] [WebGPU] Supporting DP4A in WebGPU backend

Open Jiawei-Shao opened this issue 1 year ago • 7 comments

This issue tracks progress on supporting DP4A (the new built-in functions introduced by the packed_4x8_integer_dot_product WGSL language extension) in the WebGPU backend.

  • [ ] P1. Title of this piece of the feature (PR link if available)

Triage

  • backend: WebGPU
  • needs-triage

Jiawei-Shao, Feb 22 '24

Thanks @Jiawei-Shao! Contributions are more than welcome.

tqchen, Feb 27 '24

Hi @tqchen,

I am stuck on translating int8 to WGSL, so I have to turn to you for help.

WGSL currently doesn't support 8-bit integers, so in the generated WGSL we can only load and store groups of four int8 values (int8x4) packed into a u32. How can I make TVM generate TIR that always loads and stores four int8 values together, instead of loading or storing each int8 separately?

For example,

block_read_0 = sch.cache_read(block, 0, "shared")

generates the following TIR:

X_shared_1 = T.Buffer((228,), "int8", data=X_shared, scope="shared")
for ax0, ax1 in T.grid(57, 4):
    X_1 = T.Buffer((16384,), "int8", data=X.data)
    X_shared_1[ax0 * 4 + ax1] = X_1[blockIdx_y * 8192 + i_2 * 128 + ax0 * 128 + k_0 * 4 + ax1]

which in turn generates the following WGSL (X is declared as an int8 array, which is invalid in WGSL):

for (var ax1 : i32 = 0; ax1 < 4i; ax1++) {
    X_shared[((ax0 * 4i) + ax1)] = X[(((((i32(blockIdx.y) * 8192i) + (i_2 * 128i)) + (ax0 * 128i)) + (k_0 * 4i)) + ax1)];
}

whereas what I want is the following WGSL (X_shared and X are both u32 arrays):

X_shared[ax0] = X[((((i32(blockIdx.y) * 2048i) + (i_2 * 32i)) + (ax0 * 32i)) + k_0)];

The declarations of X_shared and X are easy to handle (for example, I can change var<workgroup> X_shared : array<i8, 228> to var<workgroup> X_shared : array<u32, 57>), but the per-byte loop that loads the int8 data is hard to rewrite. Could you give me some advice on how to do it?

Here is the Python script I used for testing. I cannot upload .py files, so I renamed it to .txt: tvm_dp4a.txt

Jiawei-Shao, Mar 20 '24
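The index rewrite asked for above is purely arithmetic: once ax1 is folded into a u32 word, every stride in the int8 index is divided by 4. A small pure-Python check of that identity follows; the helper names and the sampled loop ranges are hypothetical, while the index formulas are copied from the TIR/WGSL in the question.

```python
# Hypothetical helpers; the index formulas come from the TIR/WGSL above.

def int8_index(block_y, i_2, ax0, k_0, ax1):
    # Byte-granular index into X viewed as int8[16384].
    return block_y * 8192 + i_2 * 128 + ax0 * 128 + k_0 * 4 + ax1


def u32_index(block_y, i_2, ax0, k_0):
    # Word-granular index into X viewed as u32[4096]: every stride divided by 4.
    return block_y * 2048 + i_2 * 32 + ax0 * 32 + k_0


def lanes_match(block_y, i_2, ax0, k_0):
    # The four int8 elements read by the inner ax1 loop are exactly the
    # four bytes (lanes) of a single u32 word.
    word = u32_index(block_y, i_2, ax0, k_0)
    return all(
        int8_index(block_y, i_2, ax0, k_0, ax1) == 4 * word + ax1
        for ax1 in range(4)
    )


# The shared buffer shrinks the same way: 228 int8 values fill 57 u32 words.
assert 228 // 4 == 57

# Sample ranges chosen for illustration only, not the real loop extents.
print(all(lanes_match(by, i2, a0, k0)
          for by in range(2) for i2 in range(4)
          for a0 in range(57) for k0 in range(32)))  # prints True
```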

@Jiawei-Shao, in this case I think we can directly create an array of "int8x4" (i.e. the vector type), so that all loads and stores are vectorized, and then lower int8x4 to uint32.

tqchen, Mar 21 '24

Another, simpler approach (which could be one step easier) would be to first expose dp4a as an intrinsic that takes uint32 inputs and produces an i32. That does mean we need to write special TVM programs in uint32, but at least it would serve as a first step.

tqchen, Mar 21 '24
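For reference, the uint32-in, i32-out intrinsic suggested above could be modeled as below. This is a minimal pure-Python sketch of a signed dp4a in the spirit of WGSL's dot4I8Packed, not TVM code; the function names are hypothetical, and it assumes lane 0 lives in the least significant byte. Because both operands use the same lane order, the resulting dot product does not depend on that assumption.

```python
def unpack_i8x4(word):
    """Split a 32-bit word into four signed 8-bit lanes (lane 0 = lowest byte)."""
    lanes = []
    for i in range(4):
        byte = (word >> (8 * i)) & 0xFF
        # Sign-extend the 8-bit value.
        lanes.append(byte - 256 if byte >= 128 else byte)
    return lanes


def pack_i8x4(lanes):
    """Pack four int8 values (-128..127) into one u32 (lane 0 = lowest byte)."""
    word = 0
    for i, v in enumerate(lanes):
        word |= (v & 0xFF) << (8 * i)
    return word


def dp4a_i8(a, b):
    """Signed dp4a: sum of the four lane-wise products of two packed u32 values.

    The result always fits in an i32, since |sum| <= 4 * 128 * 128 = 65536.
    """
    return sum(x * y for x, y in zip(unpack_i8x4(a), unpack_i8x4(b)))


a = pack_i8x4([1, -2, 3, -4])
b = pack_i8x4([5, 6, -7, 8])
print(hex(a), dp4a_i8(a, b))  # prints: 0xfc03fe01 -60
```

An unsigned variant in the style of dot4U8Packed would simply skip the sign extension.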

@Jiawei-Shao, in this case we can call sch.vectorize(ax1) to turn the loop into a vectorized one. https://github.com/apache/tvm/blob/main/src/target/spirv/spirv_utils.cc#L123 will rewrite a buffer with vectorized accesses to int8x4 as long as both its reads and writes are vectorized. You may need to use tensorization to vectorize the loads from this buffer.

vinx13, Mar 21 '24

import numpy as np
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def vector_copy(A: T.Buffer((4,), "int8x4"),
                    B: T.Buffer((4,), "int8x4")):
        T.func_attr({"global_symbol": "vector_copy", "tir.noalias": True})
        # Each element of A and B is one int8x4 vector, so every load and
        # store moves four int8 lanes at once.
        for i in T.serial(4):
            with T.block("B"):
                vi = T.axis.spatial(4, i)
                B[vi] = A[vi]


def main():
    mod = tvm.build(MyModule, target="llvm")
    # On the host, an int8x4 array of shape (4,) is exchanged as an int8
    # array of shape (4, 4); the last axis holds the four vector lanes.
    a_np = np.arange(16).astype("int8").reshape((4, 4))
    a = tvm.nd.empty((4,), dtype="int8x4").copyfrom(a_np)
    b = tvm.nd.empty((4,), dtype="int8x4")
    mod["vector_copy"](a, b)
    print(b.numpy())


main()

tqchen, Mar 21 '24

Here is a simple example that writes the program directly in int8x4. Note that an int8x4 array of shape (4,) is represented with an extra dimension, i.e. as shape (4, 4), where the innermost dimension comes from the vector dtype.

tqchen, Mar 21 '24
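To connect the (4, 4) host layout described above with lowering int8x4 to uint32, here is a NumPy-only sketch (it does not call TVM) showing that the same 16 bytes can be read either as four int8x4 vectors or as four 32-bit words. It assumes lane 0 is stored at the lowest address and uses an explicit little-endian dtype ('<u4') so the result does not depend on the host byte order.

```python
import numpy as np

# Host-side view of an int8x4 array of shape (4,): one row per vector,
# the innermost axis holds the four int8 lanes.
a_np = np.arange(16).astype("int8").reshape((4, 4))

# Reinterpret the same bytes as little-endian 32-bit words: each row of
# four int8 lanes becomes exactly one uint32 (lane 0 in the lowest byte).
words = a_np.view("<u4").reshape(4)

# Reinterpreting the words back as int8 recovers the original lanes.
lanes = words.view(np.int8).reshape((4, 4))

print(words.shape, [hex(w) for w in words])
assert (lanes == a_np).all()
```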

Thanks for all your reviews!

We can open another issue to track support for tensorization with dp4a on the WebGPU backend.

Jiawei-Shao, Jul 05 '24