
[BUG] for loop and while loop have different behavior

Open HarryWu99 opened this issue 2 months ago • 2 comments

Which component has the problem?

CuTe DSL

Bug Report

Steps/Code to reproduce bug

import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def tile_scheduler_get_next(cur_id, group_size, group_size_m, max_m):
    group_id = cur_id // group_size
    id_in_group = cur_id % group_size
    if cur_id == 78:
        cute.printf("cur_id={} group_size={} calc mod= {} {}", cur_id, group_size, 78%128, cur_id%group_size)
    return group_id, id_in_group

@cute.kernel
def demo_kernel():
    tx, _, _ = cute.arch.thread_idx()
    cute.printf("for loop")
    for i in range(tx, 512, 78):
        bm, bn = tile_scheduler_get_next(i, 128, 8, 32)
    i = tx
    cute.printf("while loop")
    while i < 512:
        bm, bn = tile_scheduler_get_next(i, 128, 8, 32)
        i += 78

@cute.jit
def demof():
    demo_kernel().launch(
        grid=[1, 1, 1],
        block=[1, 1, 1],
    )

torch.ones(3, 4, device="cuda")  # allocate a small CUDA tensor first (presumably to initialize the CUDA context)
demof()

output:

for loop
cur_id=78 group_size=128 calc mod= 78 1
while loop
cur_id=78 group_size=128 calc mod= 78 78

This code is used for tile swizzling in a GEMM. However, the modulo computation gives a wrong result inside the for loop (1 instead of 78), and the problem seems to be related to the function call.
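For context, the function in the repro looks like the first half of a grouped ("swizzled") tile ordering. Below is a minimal pure-Python sketch of the full mapping, assuming a Triton-style grouped ordering where group_size = group_size_m * num_tiles_n (128 = 8 * 16 here) and max_m is the number of M tiles. The helper grouped_tile_coords is hypothetical, not a CuTe DSL API. It illustrates why a wrong cur_id % group_size matters: every tile coordinate is derived from id_in_group.

```python
def grouped_tile_coords(tile_id, num_tiles_m, num_tiles_n, group_size_m):
    """Map a linear tile id to (tile_m, tile_n) with grouped ordering.

    Hypothetical helper, not a CuTe DSL API. Tiles are walked down the M
    dimension inside a band of `group_size_m` tile rows before moving to
    the next N column, which tends to improve L2 cache reuse.
    """
    tiles_per_group = group_size_m * num_tiles_n   # 8 * 16 = 128 in the repro
    group_id = tile_id // tiles_per_group
    id_in_group = tile_id % tiles_per_group        # the modulo from the bug report
    first_m = group_id * group_size_m
    # The last band may have fewer than group_size_m rows.
    rows_in_group = min(num_tiles_m - first_m, group_size_m)
    tile_m = first_m + id_in_group % rows_in_group
    tile_n = id_in_group // rows_in_group
    return tile_m, tile_n


# 32 x 16 tiles, bands of 8 tile rows (assumed to match max_m=32, group_size_m=8).
print(grouped_tile_coords(78, 32, 16, 8))  # (6, 9)
```

With these assumed sizes, tile 78 should map to (6, 9). If id_in_group came back as 1, as in the for-loop output, the same tile id would map to (1, 0), so the block would compute the wrong output tile.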

If I use a while loop that iterates the same way as the for loop, the problem goes away.

Expected behavior: cur_id % group_size should be 78 in both loops.
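As a reference for the expected behavior, here is a minimal host-side model of the two loops in plain Python, with no CuTe DSL involved (printf_line, for_loop_trace and while_loop_trace are hypothetical names). For the single launched thread (tx = 0), both loops visit 0, 78, 156, ..., 468 and should print the same line.

```python
# Host-side model of the repro's two loops (hypothetical helpers, plain Python).
GROUP_SIZE = 128
STEP = 78
STOP = 512


def printf_line(cur_id):
    """Line tile_scheduler_get_next would print for cur_id, or None."""
    if cur_id != 78:
        return None
    return (f"cur_id={cur_id} group_size={GROUP_SIZE} "
            f"calc mod= {78 % 128} {cur_id % GROUP_SIZE}")


def for_loop_trace(tx):
    """Mirror of `for i in range(tx, 512, 78)` in demo_kernel."""
    lines = []
    for i in range(tx, STOP, STEP):
        line = printf_line(i)
        if line is not None:
            lines.append(line)
    return lines


def while_loop_trace(tx):
    """Mirror of the `while i < 512: ... i += 78` loop in demo_kernel."""
    lines = []
    i = tx
    while i < STOP:
        line = printf_line(i)
        if line is not None:
            lines.append(line)
        i += STEP
    return lines


print(for_loop_trace(0))    # ['cur_id=78 group_size=128 calc mod= 78 78']
print(while_loop_trace(0))  # ['cur_id=78 group_size=128 calc mod= 78 78']
```

Since the arithmetic is identical in both loops, any difference between the device output and this trace points at how the loop is compiled rather than at the modulo itself.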

Environment details (please complete the following information):

  • Driver Version: 550.127.08
  • nvcc 12.9
  • nvidia-cutlass-dsl 4.3.0.dev0 or 4.2.1
  • torch 2.8.0

HarryWu99, Oct 26 '25

Hi @HarryWu99

Thanks for reporting, will fix in next release.

anakinxc, Oct 27 '25

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot], Nov 26 '25

This is fixed in the latest wheel. @HarryWu99, please let us know if this is still an issue.

brandon-yujie-sun, Dec 12 '25