[BUG] CuteDSL compiler hangs when calling `copy` with a mismatched copy atom and a predicate
The compiler hangs on the code below, which calls `cute.copy` with a copy atom whose element type does not match the source tensor's dtype while also passing a predicate. The compiler should error out instead of hanging.
### Steps/Code to reproduce bug
```python
import torch

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack


@cute.kernel
def hang_kernel(
    gX: cute.Tensor,
    cX: cute.Tensor,  # coordinate tensor
    shape: cute.Shape,
    tv_layout: cute.Layout,
    tiler_mn: cute.Shape,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    blkCrd = cX[(None, None), 0]
    blkX = gX[(None, None), 0]

    copy_atom_load_bf16 = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.BFloat16)
    copy_atom_load_fp32 = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32)

    thr_copy_bf16 = cute.make_tiled_copy(copy_atom_load_bf16, tv_layout, tiler_mn).get_slice(tidx)
    thr_copy_fp32 = cute.make_tiled_copy(copy_atom_load_fp32, tv_layout, tiler_mn).get_slice(tidx)

    thrFp32 = thr_copy_fp32.partition_S(blkX)
    thrBf16 = thr_copy_bf16.partition_S(blkX)
    frgX = cute.make_fragment_like(thrFp32)

    thrCrd = thr_copy_bf16.partition_S(blkCrd)
    frgPred = cute.make_fragment(frgX.shape, cutlass.Boolean)
    for i in range(cute.size(frgPred)):
        frgPred[i] = cute.elem_less(thrCrd[i], shape)

    cute.copy(copy_atom_load_bf16, thrFp32, frgX, pred=frgPred)
    # Does not hang if we copy without the predicate, or if we call copy with copy_atom_load_fp32


@cute.jit
def hang(mX):
    N = mX.shape[-1]
    warpsize = 32
    num_blocks_N = cute.ceil_div(N, warpsize)

    tiler_mn = (1, N)
    tv_layout = cute.make_layout(
        ((warpsize, 1), (1, num_blocks_N)),
        stride=((1, 1), (1, warpsize)),
    )

    print("[DSL INFO] Input Tensors:")
    print(f"[DSL INFO] mX = {mX.type}")
    print("[DSL INFO] Tiling Parameters:")
    print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block")
    print(f"[DSL INFO] tv_layout = {tv_layout}")

    mX_expanded_layout = cute.prepend(mX.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
    mX_expanded = cute.make_tensor(mX.iterator, mX_expanded_layout)
    gX = cute.zipped_divide(mX_expanded, tiler_mn)  # ((TileM, TileN), (RestM, RestN))
    print("[DSL INFO] Tiled Tensors:")
    print(f"[DSL INFO] gX = {gX.type}")

    shape = (4, N)
    idX = cute.make_identity_tensor(shape)
    cX = cute.zipped_divide(idX, tiler=tiler_mn)
    print(f"[DSL INFO] coord tensor = {cX.type}")

    hang_kernel(gX, cX, shape, tv_layout, tiler_mn).launch(
        grid=[1, 1, 1],
        block=[cute.size(tv_layout, mode=[0]), 1, 1],
    )


def run(N):
    if not torch.cuda.is_available():
        raise RuntimeError("An Ampere GPU is required to run this example!")

    print(f"Tensor dimensions: [{N}]")
    x = torch.randn(N, device=torch.device("cuda"), dtype=torch.float32)
    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")

    x_tensor = from_dlpack(x, assumed_align=16)

    print("Compiling kernel with cute.compile ...")
    compiled_func = cute.compile(hang, x_tensor)
    print("Done compiling...")


if __name__ == "__main__":
    run(1024)
```
Cc @thakkarV
Thanks for the report. Will take a look.
Thanks for the report. The quick answer is that the copy atom's copy type is not the same as the src/dst data type. The compiler handles this case correctly when `pred` is not provided; with `pred`, it does not. It is a one-line code change to fix; waiting for the pip wheel push ASAP.
```python
cute.copy(copy_atom_load_bf16, thrFp32, frgX, pred=frgPred)
```
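Until the fix lands, a minimal workaround sketch, reusing the names from `hang_kernel` above: either pick the copy atom whose element type matches the partitioned tensor's dtype, or drop the predicate. Both variants are noted in the reproducer as not hanging.

```python
# Workaround sketch (fragment; assumes the names defined in hang_kernel above).
# Option 1: use the atom whose element type (Float32) matches thrFp32's dtype.
cute.copy(copy_atom_load_fp32, thrFp32, frgX, pred=frgPred)

# Option 2: keep the BFloat16 atom but omit the predicate.
cute.copy(copy_atom_load_bf16, thrFp32, frgX)
```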