tilelang
[Feature Request] Support inputs to be a list of tensors
How do I implement a kernel that takes a list of tensors as an argument?
When implementing a grouped GEMM kernel, I want the weight input to be a list of tensors.
I tried to implement the kernel this way: pass a tensor of data_ptr() values as input, then use T.make_tensor to reconstruct each weight tensor inside the kernel.
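The pointer-table idea can be illustrated on the host with NumPy alone. This is a minimal sketch, assuming contiguous weights: `build_pointer_table` packs each array's base address into an int64 array (the analogue of a tensor of `data_ptr()` values), and `view_from_pointer` rebuilds a typed view from one address plus a known shape and dtype. Both helper names are hypothetical, and the sketch says nothing about whether `T.make_tensor` accepts an address loaded from a buffer inside a TileLang kernel, which is exactly the step that fails here.

```python
import ctypes
import numpy as np

def build_pointer_table(weights):
    """Pack the base address of each weight array into an int64 array (hypothetical helper)."""
    for w in weights:
        # A raw address plus a shape only describes a dense row-major block.
        if not w.flags["C_CONTIGUOUS"]:
            raise ValueError("each weight must be C-contiguous")
    return np.array([w.ctypes.data for w in weights], dtype=np.int64)

def view_from_pointer(ptr, shape, dtype):
    """Reinterpret a raw address as an array view (hypothetical helper).

    The view does not own the memory: the source array must stay alive.
    """
    ctype = np.ctypeslib.as_ctypes_type(np.dtype(dtype))
    typed_ptr = ctypes.cast(int(ptr), ctypes.POINTER(ctype))
    return np.ctypeslib.as_array(typed_ptr, shape=shape)

# Usage: three [K, N] weights, addressed only through the pointer table.
K, N = 4, 3
weights = [np.arange(K * N, dtype=np.float32).reshape(K, N) + 100 * g for g in range(3)]
ptrs = build_pointer_table(weights)
views = [view_from_pointer(p, (K, N), np.float32) for p in ptrs]

# The views alias the original buffers, so writes are visible through them.
weights[2][0, 0] = -1.0
```

Every weight must share the same dtype, and the shape must be known wherever the address is dereferenced, which is why the kernel snippet passes `[K, N]` and `dtype` explicitly.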
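```python
import tilelang.language as T

# batch_sum, batch_count, K, N, block_M, block_N, threads and dtype
# are defined in an enclosing scope (not shown).
@T.prim_func
def kernel(
    A: T.Tensor([batch_sum, K], dtype),  # type: ignore
    B_ptr: T.Tensor([batch_count], "int64"),  # type: ignore
    # B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
    C: T.Tensor([batch_sum, N], dtype),  # type: ignore
):
    # Placing make_tensor here (outside T.Kernel) also causes an error:
    # B_tensor_0 = T.make_tensor(B_ptr[0], [K, N], dtype=dtype)
    with T.Kernel(
            T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
            threads=threads) as (bx, by):
        # Rebuild the first weight matrix from its raw address
        B_tensor_0 = T.make_tensor(B_ptr[0], [K, N], dtype=dtype)
```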
However, this fails to compile.
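The first grid dimension, `T.ceildiv(batch_sum, block_M) + batch_count`, appears to reserve room for one extra partial row tile per group, which would be the case if each group's rows are tiled separately. Under that reading (an assumption, since the kernel body is not shown), the sketch below maps every row tile to a `(group, row_start, row_end)` triple. `batch_sizes` is a hypothetical per-group row count and rows of `A` are assumed to be stored group after group.

```python
def ceildiv(a, b):
    return (a + b - 1) // b

def tile_schedule(batch_sizes, block_M):
    """Return (group, row_start, row_end) for every row tile, group by group.

    Assumes rows of A are laid out group after group, so group g starts at
    sum(batch_sizes[:g]). A tile never crosses a group boundary.
    """
    tiles = []
    row_offset = 0
    for g, size in enumerate(batch_sizes):
        for t in range(ceildiv(size, block_M)):
            start = row_offset + t * block_M
            end = min(start + block_M, row_offset + size)
            tiles.append((g, start, end))
        row_offset += size
    return tiles

batch_sizes = [5, 1, 9]
block_M = 4
tiles = tile_schedule(batch_sizes, block_M)
grid_m = ceildiv(sum(batch_sizes), block_M) + len(batch_sizes)
# Blocks with bx >= len(tiles) would have no rows to process and could exit early.
```

Each group contributes at most one partial tile, so the number of tiles never exceeds `ceildiv(batch_sum, block_M) + batch_count`; here it is 6 tiles for a grid of 7.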
"Inputs to be a list of tensors"
I would also like to have this feature.
Any updates on this?
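Until list inputs are supported, one possible fallback is the layout already commented out in the kernel above, `B: T.Tensor([batch_count, K, N], dtype)`: stack the weights on the host into a single array. This costs one copy and requires every weight to have the same `[K, N]` shape. The NumPy function below is only a hypothetical reference for checking a kernel's output against grouped GEMM semantics (rows of group `g` in `A` multiplied by weight `g`); it is not a TileLang implementation.

```python
import numpy as np

def grouped_gemm_reference(A, weights, batch_sizes):
    """C[rows of group g] = A[rows of group g] @ weights[g] (hypothetical reference)."""
    assert A.shape[0] == sum(batch_sizes)
    assert len(weights) == len(batch_sizes)
    N = weights[0].shape[1]
    C = np.empty((A.shape[0], N), dtype=A.dtype)
    offset = 0
    for w, size in zip(weights, batch_sizes):
        # Empty groups (size == 0) produce an empty slice and are skipped naturally.
        C[offset:offset + size] = A[offset:offset + size] @ w
        offset += size
    return C

rng = np.random.default_rng(0)
K, N = 8, 6
batch_sizes = [3, 0, 5]
A = rng.standard_normal((sum(batch_sizes), K)).astype(np.float32)
weights = [rng.standard_normal((K, N)).astype(np.float32) for _ in batch_sizes]

# Fallback layout: one [batch_count, K, N] array built from the list (one copy).
B = np.stack(weights)
C_list = grouped_gemm_reference(A, weights, batch_sizes)
C_stacked = grouped_gemm_reference(A, list(B), batch_sizes)
```

Both calls see identical weight values, so the results match exactly; the stacked layout only changes how the weights are passed, not what is computed.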