
Custom tensor roll kernel produces wrong results for variations of block size and tl.constexpr

Open • Marks101 opened this issue 1 year ago • 24 comments

Dear triton team,

I am currently debugging an issue with a kernel that is supposed to replace torch.roll followed by zeroing out the first row of a 2D matrix. This is the code I have:

import torch

import triton
import triton.language as tl

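@triton.jit
def triton_roll_and_zero_first_row_kernel(in_ptr, out_ptr, NUM_ROWS: tl.constexpr, XBLOCK: tl.constexpr):
    # Each program handles XBLOCK consecutive elements of the flattened (NUM_ROWS, 2) tensor.
    # No mask: assumes numel() is a multiple of XBLOCK (2048 elements here).
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    # Recover (row, col) from the flat index of a contiguous tensor with 2 columns.
    row_idx = xindex // 2
    col_idx = xindex % 2
    # torch.roll(x, -1, 0): output row r reads input row (r + 1) % NUM_ROWS.
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    # Zero out the first output row.
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)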

def triton_roll_and_zero_first_row(x):
    assert x.size(1) == 2
    y = torch.empty_like(x)
    # Size the launch grid from the argument x, not from the global inp.
    grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']), )
    triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=256)
    return y

def roll_and_zero_first_row(x):
    x = torch.roll(x, -1, 0)
    x[0].fill_(0.0)
    return x

if __name__ == "__main__":
    inp = torch.rand(1024, 2, device="cuda")

    out_eager = roll_and_zero_first_row(inp)
    out_triton = triton_roll_and_zero_first_row(inp)

    print("eager", out_eager)
    print("triton", out_triton)
    assert torch.all(out_eager == out_triton)
    print("PASSED")

In this form, the assert does not pass on my system. It passes if I either set XBLOCK=128 or remove tl.constexpr from NUM_ROWS. This strikes me as odd. Could you please help me understand this behaviour?
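Because the failure depends only on compilation settings, it can help to first confirm that the kernel's index arithmetic is itself correct. The following is a minimal CPU-only sketch in plain Python (no GPU, no Triton) that replays the kernel's per-element index math and compares it with the semantics of torch.roll(x, -1, 0) followed by zeroing row 0. The helper names kernel_index_math and eager_semantics are hypothetical.

```python
def kernel_index_math(flat, num_rows):
    """Replay the kernel's per-element index arithmetic on a flat row-major list."""
    out = []
    for xindex in range(num_rows * 2):
        row_idx = xindex // 2
        col_idx = xindex % 2
        rolled_row = (row_idx + 1) % num_rows
        rolled_xindex = 2 * rolled_row + col_idx
        # Mirrors tl.where(row_idx == 0, 0.0, result)
        out.append(0.0 if row_idx == 0 else flat[rolled_xindex])
    return out


def eager_semantics(flat, num_rows):
    """torch.roll(x, -1, 0) followed by x[0].fill_(0.0), written with plain lists."""
    rows = [flat[2 * r: 2 * r + 2] for r in range(num_rows)]
    rolled = rows[1:] + rows[:1]  # shift by -1 along dim 0
    rolled[0] = [0.0, 0.0]
    return [v for row in rolled for v in row]


if __name__ == "__main__":
    for n in (1, 2, 3, 1024):
        data = [float(i) for i in range(2 * n)]
        assert kernel_index_math(data, n) == eager_semantics(data, n)
    print("index math matches eager semantics")
```

If these agree, the kernel's logic is sound on paper, which suggests the discrepancy comes from how the kernel is compiled rather than from the algorithm.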

The blueprint for this kernel was actually produced by applying torch.compile to roll_and_zero_first_row, and the generated kernel is affected by this issue as well. I just renamed a few variables and reordered the code to make it human-readable. If you can confirm this behavior, I would open an issue with PyTorch as well.

I am currently on triton-nightly==2.1.0.post20240108192258, torch==2.1.2 and CUDA 12.1

Marks101 • Jan 16 '24

FWIW, I tried stepping through the kernel in interpreter mode and noticed that the repro above passes with TRITON_INTERPRET=1 (but fails without interpreter mode).

bdhirsh • Jan 19 '24

I am able to reproduce this on the main branch. Comparing XBLOCK=128 and XBLOCK=256, the TTIR/TTGIR look reasonable; the only difference in the TTGIR is the blocked layout: sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4] vs. sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4].
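The sizePerThread difference follows from simple arithmetic: with 4 warps of 32 threads (the warpsPerCTA and threadsPerWarp above), a CTA has 128 threads, so XBLOCK=128 gives each thread one element while XBLOCK=256 gives each thread two contiguous elements. The sketch below is a simplified model of a 1D blocked layout, assuming XBLOCK is a multiple of the thread count so nothing is replicated; it is not Triton's layout code, and the function name blocked_1d is hypothetical.

```python
def blocked_1d(xblock, threads_per_warp=32, warps_per_cta=4):
    """Return sizePerThread and the block-local element indices owned by each thread.

    Simplified model of a 1D blocked layout: assumes xblock is a multiple of the
    number of threads, so no element is replicated across threads.
    """
    num_threads = threads_per_warp * warps_per_cta
    if xblock % num_threads != 0:
        raise ValueError("model assumes XBLOCK is a multiple of the thread count")
    size_per_thread = xblock // num_threads
    owned = {
        tid: list(range(tid * size_per_thread, (tid + 1) * size_per_thread))
        for tid in range(num_threads)
    }
    return size_per_thread, owned


if __name__ == "__main__":
    for xblock in (128, 256):
        spt, owned = blocked_1d(xblock)
        print(f"XBLOCK={xblock}: sizePerThread=[{spt}], thread 5 owns {owned[5]}")
```

Under this model, for XBLOCK=256 thread t owns block-local elements 2t and 2t+1 (thread 5 owns [10, 11]), which lines up with the PTX computing (tid & 127) << 1 and using the two-element ld.global.v2 / st.global.v2 accesses.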

The differences in the LLIR also look reasonable. One difference in the PTX is that for XBLOCK=128 there is extra logic on ctaid.x: bfe.u32 %r8, %r1, 24, 1;

It is still not clear to me why XBLOCK=256 doesn't work.

Attaching the PTX files: b256-ptx.txt, b128-ptx.txt

manman-ren • Jan 22 '24

With "export DISABLE_LLVM_OPT=1", the test case passes, so the issue is related to LLVM optimizations.
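One way to compare these configurations side by side is to rerun the repro in a fresh subprocess per configuration, each with its own environment and an isolated cache directory so a binary cached by a previous run is not reused. This is a minimal sketch, assuming the repro is saved as repro.py (a hypothetical file name); TRITON_INTERPRET and DISABLE_LLVM_OPT are the variables mentioned in this thread, TRITON_CACHE_DIR is Triton's cache-location variable, and make_env/run_all are hypothetical helpers.

```python
import os
import subprocess
import sys
import tempfile

# Configurations discussed in this thread (environment values must be strings).
CONFIGS = {
    "default": {},
    "interpreter": {"TRITON_INTERPRET": "1"},
    "no_llvm_opt": {"DISABLE_LLVM_OPT": "1"},
}


def make_env(overrides, cache_dir):
    """Copy the current environment, apply overrides, and point Triton at a fresh cache."""
    env = dict(os.environ)
    # Drop variables belonging to any config so each run starts clean.
    for cfg in CONFIGS.values():
        for key in cfg:
            env.pop(key, None)
    env.update(overrides)
    env["TRITON_CACHE_DIR"] = cache_dir
    return env


def run_all(script="repro.py"):
    """Run the repro once per configuration; a non-zero exit code means the assert failed."""
    results = {}
    for name, overrides in CONFIGS.items():
        with tempfile.TemporaryDirectory() as cache_dir:
            proc = subprocess.run(
                [sys.executable, script],
                env=make_env(overrides, cache_dir),
                capture_output=True,
                text=True,
            )
        results[name] = proc.returncode == 0
    return results


if __name__ == "__main__":
    if os.path.exists("repro.py"):
        for name, ok in run_all().items():
            print(f"{name}: {'PASSED' if ok else 'FAILED'}")
    else:
        print("save the repro from this issue as repro.py first")
```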

manman-ren • Jan 22 '24

If we disable InstCombine in optimize_module, the test passes, but I haven't figured out why yet. The cause is not obvious to me from the differences in the PTX (disable-opt vs. with-opt, diff below). Attachments: disable-opt-llir.txt, disable-opt-ptx.txt, with-opt-llir.txt, with-opt-ptx.txt

@@ -28,8 +26,8 @@
 $L__tmp0:
 	.loc	1 9 36
 	mov.u32 	%r6, %tid.x;
-	and.b32  	%r7, %r6, 127;
-	shl.b32 	%r8, %r7, 1;
+	shl.b32 	%r7, %r6, 1;
+	and.b32  	%r8, %r7, 254;
 	.loc	1 8 28
 	// begin inline asm
 	mov.u32 %r1, %ctaid.x;
@@ -60,19 +58,14 @@
 	mov.u32 %r3, 0x0;
 	@%p1 ld.global.v2.b32 { %r2, %r3 }, [ %rd1 + 0 ];
 	// end inline asm
-	mov.b32 	%f1, %r2;
-	mov.b32 	%f2, %r3;
 	.loc	1 15 33
-	setp.eq.s32 	%p3, %r11, 0;
-	.loc	1 15 41
-	selp.f32 	%f3, 0f00000000, %f1, %p3;
-	selp.f32 	%f4, 0f00000000, %f2, %p3;
+	setp.eq.s32 	%p3, %r10, 0;
 	.loc	1 16 23
 	mul.wide.s32 	%rd6, %r10, 4;
 	add.s64 	%rd2, %rd4, %rd6;
 	.loc	1 16 31
-	mov.b32 	%r4, %f3;
-	mov.b32 	%r5, %f4;
+	selp.b32 	%r4, 0, %r2, %p3;
+	selp.b32 	%r5, 0, %r3, %p3;
 	// begin inline asm
 	@%p1 st.global.v2.b32 [ %rd2 + 0 ], { %r4, %r5 };
 	// end inline asm

manman-ren • Jan 23 '24

It looks like the predication is different in the two cases.

Disable opt:

	mov.u32 %r1, %ctaid.x;
	shl.b32 	%r9, %r1, 8;
	or.b32  	%r10, %r9, %r8;
	shr.s32 	%r11, %r10, 1;
	setp.eq.s32 	%p3, %r11, 0;

With opt:

	mov.u32 %r1, %ctaid.x;
	shl.b32 	%r9, %r1, 8;
	or.b32  	%r10, %r9, %r8;
	setp.eq.s32 	%p3, %r10, 0;

htyu • Jan 23 '24

Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?

manman-ren avatar Jan 23 '24 19:01 manman-ren

Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?

I see, it could make a difference if r10 has a value of 1?

htyu • Jan 23 '24

If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)

jlebar • Jan 23 '24

If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)

That is possible. I can try it and see if it makes a difference. Mine is currently cuda-12.1.

manman-ren • Jan 23 '24

Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?

I see, it could make a difference if r10 has a value of 1?

Yes, if r10 can be 1.

	mov.u32 	%r6, %tid.x;
	shl.b32 	%r7, %r6, 1;
	and.b32  	%r8, %r7, 254;
	mov.u32 %r1, %ctaid.x;
	shl.b32 	%r9, %r1, 8;
	or.b32  	%r10, %r9, %r8;

%r8 and %r9 should both be multiples of 2, so %r10 (their bitwise OR) is also a multiple of 2 and can never be 1.
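The argument can be checked exhaustively for this launch. Assuming 8 programs (2048 elements / XBLOCK=256) and 128 threads per CTA, the sketch below recomputes %r10 from the PTX above and confirms that it is always even, so setp.eq.s32 on %r10 (optimized) and on %r10 >> 1 (the shr.s32 in the unoptimized PTX) agree for every thread. The helper names r10 and predicates_agree are hypothetical.

```python
def r10(ctaid, tid):
    """Recompute %r10 from the PTX: (ctaid << 8) | ((tid << 1) & 254)."""
    r8 = (tid << 1) & 254
    r9 = ctaid << 8
    return r9 | r8


def predicates_agree(num_programs=8, num_threads=128):
    """Compare the unoptimized and optimized predicate for every (ctaid, tid)."""
    for ctaid in range(num_programs):
        for tid in range(num_threads):
            v = r10(ctaid, tid)
            assert v % 2 == 0, "r10 should always be even"
            # Unoptimized: setp.eq.s32 on (r10 >> 1); optimized: setp.eq.s32 on r10.
            if ((v >> 1) == 0) != (v == 0):
                return False
    return True


if __name__ == "__main__":
    print(predicates_agree())  # True
    # The two checks only diverge at r10 == 1, which the PTX above never produces.
    print(((1 >> 1) == 0) != (1 == 0))  # True
```

Since the two predicates are equivalent on every value this kernel can produce, the InstCombine rewrite appears to be value-preserving here, which points the investigation downstream of LLVM.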

manman-ren • Jan 23 '24

If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)

That is possible. I can try it and see if it makes a difference. Mine is currently cuda-12.1.

You can also disable optimizations in ptxas with --opt-level 0.
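A minimal sketch of recompiling a dumped PTX file with ptxas optimizations turned off. The target architecture sm_80 and the file names are placeholders (use the arch the kernel was originally compiled for); -arch, --opt-level and -o are standard ptxas options, while ptxas_cmd and compile_ptx are hypothetical helpers.

```python
import shutil
import subprocess


def ptxas_cmd(ptx_path, cubin_path, arch="sm_80", opt_level=0, ptxas="ptxas"):
    """Build a ptxas command line; opt_level=0 disables ptxas optimizations."""
    return [
        ptxas,
        "-arch", arch,
        "--opt-level", str(opt_level),
        ptx_path,
        "-o", cubin_path,
    ]


def compile_ptx(ptx_path, cubin_path, **kwargs):
    """Run ptxas if it is on PATH; raises CalledProcessError if compilation fails."""
    cmd = ptxas_cmd(ptx_path, cubin_path, **kwargs)
    if shutil.which(cmd[0]) is None:
        raise FileNotFoundError(f"{cmd[0]} not found on PATH")
    subprocess.run(cmd, check=True)
    return cubin_path


if __name__ == "__main__":
    # Placeholder file names; replace with the dumped kernel's .ptx path.
    print(" ".join(ptxas_cmd("kernel.ptx", "kernel.cubin")))
```

Keeping the command construction separate from execution makes it easy to print the exact invocation and rerun it by hand with a different ptxas version.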

ThomasRaoux • Jan 23 '24

Disabling ptxas optimizations also fixed the problem. Attachments: def-sass.txt, no-ptxas-opt-sass.txt

Patch to enable debugging: https://github.com/openai/triton/pull/2995

manman-ren • Jan 23 '24

Disabling optimization for ptxas also fixed the problem. def-sass.txt no-ptxas-opt-sass.txt

Patch to enable debugging: #2995

Thanks for giving it a shot. So the problem went away with LLVM optimizations on but ptxas optimizations off?

htyu • Jan 23 '24

A quick summary:

  1. Default config (with LLVM optimizations, with ptxas optimizations): BLOCKSIZE of 128 works; BLOCKSIZE of 256 fails.
  2. LLVM optimizations disabled via DISABLE_LLVM_OPT, with ptxas optimizations: BLOCKSIZE of 256 works.
  2a. With LLVM optimizations (O0), disabling InstCombinePass here https://github.com/openai/triton/blob/e6e5d5468e92ed3af3e40babdd55c3da506ab01f/python/src/llvm.cc#L190: BLOCKSIZE of 256 works.
  3. LLVM optimizations enabled, ptxas optimizations disabled: BLOCKSIZE of 256 works.

I haven't looked at the differences in the SASS for item 3 yet.

manman-ren • Jan 24 '24

I attempted to debug this with cuda-gdb, and it turned out that the issue is likely in ptxas: the predicate update seems to be gone from the SASS (i.e., P0 is used but never set).

          IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28]
          S2R R0, SR_TID.X
          IMAD.MOV.U32 R6, RZ, RZ, 0x4
          ULDC.64 UR4, c[0x0][0x118]
          S2R R5, SR_CTAID.X
          IMAD.SHL.U32 R0, R0, 0x2, RZ
          LOP3.LUT R0, R0, 0xfe, RZ, 0xc0, !PT
          PRMT R4, R0, 0x6540, R5
          LEA.HI.SX32 R2, R4, 0x1, 0x1f
          SHF.R.S32.HI R3, RZ, 0x1f, R2
          LEA.HI R3, R3, R2, RZ, 0xa
          LOP3.LUT R3, R3, 0x7ffffc00, RZ, 0xc0, !PT
          IMAD.IADD R3, R2, 0x1, -R3
          IMAD.SHL.U32 R3, R3, 0x2, RZ
          IMAD.WIDE R2, R3, R6, c[0x0][0x160]
          LDG.E.64 R2, [R2.64]
          PRMT RZ, R0, 0x6540, R5
          IMAD.WIDE R4, R4, R6, c[0x0][0x168]
          SEL R6, R2, RZ, P0
          SEL R7, R3, RZ, P0
          STG.E.64 [R4.64], R6

Where the predicate is set in the PTX:

	setp.eq.s32 	%p3, %r10, 0;
	.loc	1 16 23
	mul.wide.s32 	%rd6, %r10, 4;
	add.s64 	%rd2, %rd4, %rd6;
	.loc	1 16 31
	selp.b32 	%r4, 0, %r2, %p3;
	selp.b32 	%r5, 0, %r3, %p3;

manman-ren • Jan 26 '24


Interesting findings! ptxas bugs are always very hard to track down :(

Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)

You can also run different versions of ptxas and override the generated cubin, if that is simpler.

ThomasRaoux • Jan 26 '24

You can also run ptxas with different versions and override the cubin generated if it is simpler

That would be super useful if we could swap the cubin and run it. How do we do that?

Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)

I haven't. Let me ask around how to get the latest.

manman-ren • Jan 26 '24

You can also run ptxas with different versions and override the cubin generated if it is simpler

This will be super useful if we can swap the cubin and run it. How do we do it?

It's a bit of an experimental debug feature, but see here: https://github.com/openai/triton/blob/main/python/triton/compiler/compiler.py#L218 The flow I usually use:

  1. Run export TRITON_KERNEL_DUMP=1 and clear the cache
  2. Run the test: python ....
  3. Find the kernel you are interested in under ~/.triton/dump/kernel_hash/
  4. Copy the new cubin into ~/.triton/override/, using the same relative path as in the dump directory: ~/.triton/override/kernel_hash/kernel_name.cubin
  5. Clear the cache and run TRITON_KERNEL_OVERRIDE=1 python....
  6. Check the console for the message "Overriding kernel with file kernel_name.cubin" to make sure the override happened
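Step 4 of the flow above can be scripted. The sketch below copies a recompiled cubin into the override tree at the same relative path the dump used; kernel_hash and kernel_name are placeholders for the values found in step 3, the default roots are the ~/.triton/dump and ~/.triton/override directories named above, and stage_override is a hypothetical helper.

```python
import shutil
from pathlib import Path

DUMP_ROOT = Path.home() / ".triton" / "dump"
OVERRIDE_ROOT = Path.home() / ".triton" / "override"


def stage_override(new_cubin, kernel_hash, kernel_name,
                   dump_root=DUMP_ROOT, override_root=OVERRIDE_ROOT):
    """Copy new_cubin to <override_root>/<kernel_hash>/<kernel_name>.cubin.

    Checks that the dumped kernel directory exists first, so a typo in the
    hash is caught before rerunning with TRITON_KERNEL_OVERRIDE=1.
    """
    dumped_dir = Path(dump_root) / kernel_hash
    if not dumped_dir.is_dir():
        raise FileNotFoundError(f"no dumped kernel at {dumped_dir}")
    target = Path(override_root) / kernel_hash / f"{kernel_name}.cubin"
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(new_cubin, target)
    return target
```

After staging, steps 5 and 6 still apply: clear the cache, rerun with TRITON_KERNEL_OVERRIDE=1, and look for the "Overriding kernel with file" message.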

Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)

I haven't. Let me ask around how to get the latest.

You can see an example here: https://github.com/ThomasRaoux/triton/commit/edb74eb5bda10c7019f3717acb5d0c97eb2a411d

ThomasRaoux • Jan 26 '24

I tried the cuda_12.3.r12.3 ptxas; it generates the same code. I am not familiar with SASS: is it possible that P0 is set implicitly by some instruction? But the incorrect outputs are all 0.0, which looks like the predicate always selects 0.0.

manman-ren • Jan 26 '24

If it's very likely a ptxas bug, we could prepare a reproducer and send it to the NVIDIA compiler folks. In my experience, they usually respond quickly.

Jokeren • Jan 27 '24

@Jokeren Yes, here is the repro and a README; let me know how to send it over to the NVIDIA compiler folks. Attachments: repro-ptx.txt, README.txt

manman-ren • Jan 29 '24

Filed as https://developer.nvidia.com/bugs/4474599

manman-ren • Jan 29 '24

Reply from NVIDIA: This is a ptxas optimization bug that we've already fixed internally. It will be available in an update release of CUDA 12.4 soon. The optimization was incorrectly transforming a LOP3 (which supports predicate output) into a PRMT (which doesn't support predicate out).
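The explanation can be illustrated with a small model of the byte-permute instruction. The sketch below emulates PTX prmt.b32 in its default mode: source bytes 0-3 come from the first operand and bytes 4-7 from the second, each selector nibble picks one byte, and a set high bit in the nibble replicates that byte's sign bit. Assuming the SASS operand order is PRMT Rd, Ra, selector, Rb, it shows that PRMT R4, R0, 0x6540, R5 in the SASS above computes the same value as the shl.b32/or.b32 pair in the PTX, (ctaid << 8) | r8. So the rewrite is valid for the value. What it cannot do is produce a predicate, which is consistent with P0 never being written (note the PRMT RZ, R0, 0x6540, R5 whose result goes to the zero register). The helper name prmt is hypothetical.

```python
def prmt(a, b, selector):
    """Emulate PTX prmt.b32 in default mode: d = prmt(a, b, selector)."""
    # Source bytes 0-3 come from a, bytes 4-7 from b.
    src = ((b & 0xFFFFFFFF) << 32) | (a & 0xFFFFFFFF)
    result = 0
    for i in range(4):
        nibble = (selector >> (4 * i)) & 0xF
        byte = (src >> (8 * (nibble & 0x7))) & 0xFF
        if nibble & 0x8:
            # High bit of the nibble set: replicate the sign bit of the selected byte.
            byte = 0xFF if byte & 0x80 else 0x00
        result |= byte << (8 * i)
    return result


if __name__ == "__main__":
    # R0 = (tid << 1) & 0xfe and R5 = ctaid, as in the SASS above.
    for ctaid in range(8):
        for tid in range(128):
            r0 = (tid << 1) & 0xFE
            assert prmt(r0, ctaid, 0x6540) == (ctaid << 8) | r0
    print("PRMT with selector 0x6540 matches (ctaid << 8) | r0 for this launch")
```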

manman-ren • Jan 30 '24

Thanks to everyone involved for figuring this out 🙌

Marks101 • Jan 31 '24

CUDA 12.4 is out. Has anyone tested it?

jpilaul • Mar 18 '24

Still running into the same issues with CUDA 12.4, torch=2.5.0.dev20240709+cu124, triton=2.3.1

windsornguyen • Jul 09 '24