[BUG] index of T.Tensor cannot be int64

Open victbr opened this issue 2 months ago • 0 comments

Required prerequisites

What version of TileLang are you using?

0.1.6

System information

PyTorch version: 2.7.0+cu128

Problem description

When a tensor element is written through an index of type int64, compilation fails with the error: "TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32". This happens even when the index range genuinely requires int64. If the index is instead forced to int32, compilation fails with a different error: "tvm.error.InternalError: Check failed: arg.dtype() == value.dtype() (int32 vs. int64) : "

Reproducible example code

import torch
import tilelang
import tilelang.language as T

@tilelang.jit
def set_cache_kernel(
    S,
    B,
    D,
    dtype="float32",
):
    @T.prim_func
    def main(
        pos: T.Tensor([S,], "int64"),  # type: ignore  `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
        # pos: T.Tensor([S,], "int32"),  # type: ignore  `tvm.error.InternalError: Check failed: arg.dtype() == value.dtype() (int32 vs. int64) :`
        value: T.Tensor([S, D], dtype),  # type: ignore
        cache: T.Tensor([B, D], dtype),  # type: ignore
    ):
        with T.Kernel(S, threads=128) as bx:
            slot = pos[bx]
            for i in T.Parallel(D):
                cache[slot, i] = value[bx, i]
    return main


def set_cache():
    B = 32768*64*32*64  # = 2**32 rows, so row indices into `cache` exceed the int32 range
    D = 2
    cache = torch.rand((B, D), device="cuda", dtype=torch.float32)
    S = 10
    kernel = set_cache_kernel(S, B, D)

if __name__ == "__main__":
    set_cache()
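
As a rough check of why int64 indices are needed for these shapes, the sketch below works out the index range implied by the reproducer in plain Python, independent of TileLang. The helper `wrap_to_int32` is hypothetical: it only emulates a two's-complement truncation to 32 bits for illustration and is not a TileLang or TVM function.

```python
# Index-range check for the reproducer's shapes (plain Python, no TileLang needed).
INT32_MAX = 2**31 - 1

B = 32768 * 64 * 32 * 64  # number of rows in `cache`, as in the reproducer
D = 2                     # number of columns in `cache`

largest_row = B - 1       # largest valid value of `slot`
largest_flat = B * D - 1  # largest flat element offset into `cache`


def wrap_to_int32(x: int) -> int:
    """Hypothetical helper: emulate truncating x to a signed 32-bit integer."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x


print(B == 2**32)                   # True
print(largest_row > INT32_MAX)      # True
print(largest_flat > INT32_MAX)     # True
print(wrap_to_int32(largest_row))   # -1
```

With these shapes, the largest row index does not fit in int32 and would wrap to a negative value if truncated, so casting `pos` to int32 cannot address every row of `cache`.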

Traceback


Expected behavior

No response

Additional context

No response

victbr • Oct 23 '25 12:10