[BUG][CuTe DSL] for loop is wrong if step is negative

Open tridao opened this issue 6 months ago • 3 comments

Describe the bug

If the loop is dynamic and the step is negative, the for loop behaves incorrectly. In the example below, the loop body is never entered.

cc: @thakkarV

Steps/Code to reproduce bug

import cutlass
import cutlass.cute as cute


@cute.kernel
def kernel(end: cutlass.Int32):
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Before for loop")
    # For loop is correct with positive step, but wrong with negative step
    # for i in range(end, end + 10, 1):
    for i in range(end, 0, -1):
        if tidx == 0:
            cute.printf("i = {}", i)


@cute.jit
def loop_neg_step():
    cutlass.cuda.initialize_cuda_context()
    kernel(cutlass.Int32(10)).launch(
        grid=(1, 1, 1),
        block=(32, 1, 1),
    )

if __name__ == "__main__":
    loop_neg_step()

tridao avatar May 23 '25 16:05 tridao

Thanks for reporting this. The problem is that MLIR's scf.for requires the step to be positive. We will try to address this soon.

As a workaround, you can use a while loop for a negative step:

i = end
while i > 0:  # stop before 0, matching range(end, 0, -1)
    if tidx == 0:
        cute.printf("i = {}", i)
    i -= 1

This should work as expected.
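
For anyone curious how a negative step can be mapped onto a loop op that does not accept one, here is a minimal plain-Python sketch. It is only an illustration of one possible normalization (compute a trip count, then derive the index from a counter that steps by +1); `normalized_range` is a hypothetical helper, not part of CUTLASS or MLIR, and not necessarily how the eventual fix works.

```python
def normalized_range(start: int, stop: int, step: int):
    """Yield the same values as range(start, stop, step), but drive the
    iteration with a counter that only ever steps by +1.

    Hypothetical helper for illustration; not a CUTLASS or MLIR API.
    """
    if step == 0:
        raise ValueError("step must not be zero")
    if step > 0:
        # ceil((stop - start) / step), clamped at zero
        trip_count = max(0, (stop - start + step - 1) // step)
    else:
        # ceil((start - stop) / -step), clamped at zero
        trip_count = max(0, (start - stop - step - 1) // (-step))
    # The counter k runs 0, 1, ..., trip_count - 1 with a positive step;
    # the user-visible index is recovered as start + k * step.
    for k in range(trip_count):
        yield start + k * step


# Same bounds as the loop in the report: range(end, 0, -1) with end = 10.
values = list(normalized_range(10, 0, -1))
```

With the bounds from the report, `values` counts down from 10 to 1, which matches Python's own `range(10, 0, -1)`.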

anakinxc avatar May 24 '25 00:05 anakinxc

@anakinxc let's discuss this internally. As an external user, I would expect for loops with negative increments to just work. The DSL should abstract MLIR warts away from the user.

thakkarV avatar May 25 '25 19:05 thakkarV

Thanks for trying out the DSL! I appreciate the report — this is indeed a bug, and @anakinxc has already identified the root cause. This program should work, and we’ll work on fixing it.

FWIW, if the loop is not dynamic, it should behave correctly. You can use range_constexpr, but note that the loop bounds must also be constexpr (i.e., plain Python values).

For example:

end = 10
for i in cutlass.range_constexpr(end, 0, -1):
    if tidx == 0:
        cute.printf("i = {}", i)
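
As a rough mental model of why constexpr bounds sidestep the problem, here is a plain-Python sketch with no CUTLASS imports. The assumption is that a constexpr loop is run by the Python interpreter while the kernel is being traced, so each iteration is emitted as straight-line code and no runtime loop op is involved. `trace_unrolled` and the string "ops" are hypothetical stand-ins, not the DSL's actual internals.

```python
def trace_unrolled(start: int, stop: int, step: int) -> list[str]:
    """Emulate trace-time unrolling: iterate with an ordinary Python range
    and record one 'op' per iteration instead of emitting a loop.

    Hypothetical illustration only; not the CuTe DSL implementation.
    """
    ops = []
    for i in range(start, stop, step):  # plain Python range, negative step is fine
        ops.append(f"printf i = {i}")
    return ops


# Same bounds as the example above: end = 10, counting down to 1.
ops = trace_unrolled(10, 0, -1)
```

Under this model the negative step is handled entirely by Python's own `range`, which is why the bounds have to be known Python values rather than runtime `Int32` arguments.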

Thanks again for the feedback!

grypp avatar May 26 '25 07:05 grypp

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Jun 25 '25 23:06 github-actions[bot]

To close the loop, this is fixed in 4.1. Feel free to let us know if you see any other issues :-)

brandon-yujie-sun avatar Jul 04 '25 02:07 brandon-yujie-sun

Great, thank you so much! This bug is fixed in 4.1.

tridao avatar Jul 04 '25 03:07 tridao