[BUG] for loop and while loop have different behavior
Which component has the problem?
CuTe DSL
Bug Report
Steps/Code to reproduce bug
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def tile_scheduler_get_next(cur_id, group_size, group_size_m, max_m):
    group_id = cur_id // group_size
    id_in_group = cur_id % group_size
    if cur_id == 78:
        cute.printf("cur_id={} group_size={} calc mod= {} {}", cur_id, group_size, 78 % 128, cur_id % group_size)
    return group_id, id_in_group

@cute.kernel
def demo_kernel():
    tx, _, _ = cute.arch.thread_idx()

    cute.printf("for loop")
    for i in range(tx, 512, 78):
        bm, bn = tile_scheduler_get_next(i, 128, 8, 32)

    i = tx
    cute.printf("while loop")
    while i < 512:
        bm, bn = tile_scheduler_get_next(i, 128, 8, 32)
        i += 78

@cute.jit
def demof():
    demo_kernel().launch(
        grid=[1, 1, 1],
        block=[1, 1, 1],
    )

torch.ones(3, 4, device="cuda")
demof()
Output:
for loop
cur_id=78 group_size=128 calc mod= 78 1
while loop
cur_id=78 group_size=128 calc mod= 78 78
This code is used for tile swizzling in GEMM. I found that the modulo calculation produces a strange result (78 % 128 evaluates to 1 instead of 78), and it seems to be related to function calls: the wrong value only appears when cur_id % group_size is computed inside a function called from a for loop.
If I rewrite the for loop as an equivalent while loop, the problem goes away.
Expected behavior: cur_id % group_size should be 78 in both loops.
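For reference, the two loop forms are arithmetically equivalent in plain Python. A minimal sketch outside the DSL (with tx fixed to 0, the only thread launched above, and the unused group_size_m/max_m parameters dropped) of what both loops are expected to compute:

```python
def tile_scheduler_get_next(cur_id, group_size):
    # Same arithmetic as the @cute.jit function in the reproduction.
    group_id = cur_id // group_size
    id_in_group = cur_id % group_size
    return group_id, id_in_group

tx = 0  # thread 0, matching block=[1, 1, 1]

# for-loop form
for_results = [tile_scheduler_get_next(i, 128) for i in range(tx, 512, 78)]

# equivalent while-loop form
while_results = []
i = tx
while i < 512:
    while_results.append(tile_scheduler_get_next(i, 128))
    i += 78

# Both forms visit cur_id = 0, 78, 156, ..., 468 and must agree;
# at cur_id = 78 the expected result is (0, 78), not (0, 1).
assert for_results == while_results
assert for_results[1] == (0, 78)
```

This is what the while-loop path of the kernel prints; the bug is that the for-loop path yields 1 for the same expression.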
Environment details (please complete the following information):
- Driver Version: 550.127.08
- nvcc 12.9
- nvidia-cutlass-dsl 4.3.0.dev0 or 4.2.1
- torch 2.8.0
Hi @HarryWu99,
Thanks for reporting; we will fix this in the next release.
This is fixed in the latest wheel, @HarryWu99 please let us know if this is still an issue.