[BUG] index of T.Tensor cannot be int64
Required prerequisites
- [x] I have read the documentation https://tilelang.com.
- [x] I have searched the Issue Tracker to confirm this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.6
System information
PyTorch version: 2.7.0+cu128
Problem description
When writing to a tensor, passing an index value of type int64 causes compilation to fail with:

TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32

This happens even when the index range genuinely requires int64 (e.g., when the destination tensor has more rows than int32 can represent). If the index is instead forcibly cast to int32 (e.g., by declaring the pos tensor as "int32"), a different error is reported:

tvm.error.InternalError: Check failed: arg.dtype() == value.dtype() (int32 vs. int64) :
Reproducible example code
import torch
import tilelang
import tilelang.language as T


@tilelang.jit
def set_cache_kernel(
    S,
    B,
    D,
    dtype="float32",
):
    @T.prim_func
    def main(
        pos: T.Tensor([S,], "int64"),  # type: ignore  # `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
        # pos: T.Tensor([S,], "int32"),  # type: ignore  # `tvm.error.InternalError: Check failed: arg.dtype() == value.dtype() (int32 vs. int64) :`
        value: T.Tensor([S, D], dtype),  # type: ignore
        cache: T.Tensor([B, D], dtype),  # type: ignore
    ):
        with T.Kernel(S, threads=128) as bx:
            # Load the destination row index; it is int64 here.
            slot = pos[bx]
            for i in T.Parallel(D):
                cache[slot, i] = value[bx, i]

    return main


def set_cache():
    B = 32768 * 64 * 32 * 64  # 2**32 rows, so row indices do not fit in int32
    D = 2
    cache = torch.rand((B, D), device="cuda", dtype=torch.float32)
    S = 10
    kernel = set_cache_kernel(S, B, D)
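    # Assumed completion of the repro (not in the original report): build an
    # int64 position tensor plus values and invoke the jit-compiled kernel,
    # which is how the index dtype would actually be exercised at runtime.
    pos = torch.randint(0, B, (S,), device="cuda", dtype=torch.int64)
    value = torch.rand((S, D), device="cuda", dtype=torch.float32)
    kernel(pos, value, cache)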

if __name__ == "__main__":
    set_cache()
Traceback
Expected behavior
No response
Additional context
No response