tilelang

[Feature Request] Support inputs to be a list of tensors

Open • smallscientist1 opened this issue on May 28, 2025 • 1 comment

How do I implement a kernel that takes a list of tensors as an argument?

When implementing a grouped GEMM kernel, I want the weight input to be a list of tensors. My approach was to build an int64 tensor holding each weight's data_ptr(), pass that tensor as the input, and then use T.make_tensor to reconstruct each weight tensor inside the kernel:

    import tilelang.language as T

    @T.prim_func
    def kernel(
            A: T.Tensor([batch_sum, K], dtype),  # type: ignore
            # One int64 entry per group: the data_ptr() of that group's [K, N] weight
            B_ptr: T.Tensor([batch_count], "int64"),  # type: ignore
            # Stacked-weights version, commented out:
            # B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
            C: T.Tensor([batch_sum, N], dtype),  # type: ignore
    ):
        # Placing this here (outside T.Kernel) also causes an error:
        # B_tensor_0 = T.make_tensor(B_ptr[0], [K, N], dtype=dtype)
        with T.Kernel(
                T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
                threads=threads) as (bx, by):
            # Rebuild group 0's [K, N] weight tensor from its raw pointer
            B_tensor_0 = T.make_tensor(B_ptr[0], [K, N], dtype=dtype)

However, compiling this kernel fails with an error.
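The pointer-table approach described above (pack every weight's data_ptr() into an int64 tensor, then rebuild each weight from its address inside the kernel) can be illustrated on the CPU with Python's standard `ctypes` module. This is only a sketch of the idea, not tilelang code: `make_buffer`, `pack_pointer_table` and `grouped_gemm` are hypothetical helpers, `ctypes.addressof` stands in for `torch.Tensor.data_ptr()`, and `from_address` stands in for `T.make_tensor`. It assumes row-major float32 weights that all share the shape K x N, and group sizes that add up to `batch_sum`.

```python
import ctypes

# CPU-only sketch of the pointer-table idea. Hypothetical helpers, not tilelang API.
# ctypes.addressof() stands in for torch.Tensor.data_ptr(), and
# from_address() stands in for T.make_tensor. Weights are assumed to be
# row-major float32 buffers of identical shape K x N.


def make_buffer(values):
    """Allocate a contiguous float32 buffer holding `values` (row-major)."""
    return (ctypes.c_float * len(values))(*values)


def pack_pointer_table(buffers):
    """Store the base address of every buffer in one int64 array (the B_ptr analogue)."""
    return (ctypes.c_int64 * len(buffers))(*(ctypes.addressof(b) for b in buffers))


def grouped_gemm(A, group_sizes, b_ptr, K, N):
    """C[rows of group g] = A[rows of group g] @ B_g, with B_g reached via b_ptr[g]."""
    assert sum(group_sizes) == len(A), "group sizes must add up to batch_sum"
    C = []
    row = 0
    for g, m in enumerate(group_sizes):
        # Reinterpret the raw address as a flat K*N float32 array (no copy).
        B = (ctypes.c_float * (K * N)).from_address(b_ptr[g])
        for i in range(row, row + m):
            C.append([sum(A[i][k] * B[k * N + n] for k in range(K)) for n in range(N)])
        row += m
    return C


# Usage: two groups, K = N = 2.
K, N = 2, 2
weights = [
    make_buffer([1.0, 0.0, 0.0, 1.0]),  # B_0 = identity
    make_buffer([2.0, 0.0, 0.0, 2.0]),  # B_1 = 2 * identity
]
b_ptr = pack_pointer_table(weights)  # `weights` must stay alive while b_ptr is used
A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # batch_sum = 3 rows: 2 in group 0, 1 in group 1
C = grouped_gemm(A, [2, 1], b_ptr, K, N)
print(C)  # prints [[1.0, 2.0], [3.0, 4.0], [10.0, 12.0]]
```

The sketch makes one constraint explicit: the pointer table holds bare addresses and no references, so the caller must keep every weight buffer alive for as long as the table is in use. The same applies on the GPU side, since data_ptr() returns a plain integer and does not keep the source tensor alive.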

smallscientist1 • May 28 '25 12:05

"Inputs to be a list of tensors"

I would also like this feature.

Any updates on this?

a1600012888 • Oct 11 '25 20:10