Custom tensor roll kernel produces wrong results for variations of block size and tl.constexpr
Dear Triton team,
I am currently debugging an issue with a kernel that is supposed to replace torch.roll followed by zeroing out the first row of a 2D matrix. This is the code I have:
import torch
import triton
import triton.language as tl


@triton.jit
def triton_roll_and_zero_first_row_kernel(in_ptr, out_ptr, NUM_ROWS: tl.constexpr, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    row_idx = xindex // 2
    col_idx = xindex % 2
    # out[row] = in[(row + 1) % NUM_ROWS], i.e. torch.roll(x, -1, 0)
    rolled_row = (row_idx + 1) % NUM_ROWS
    rolled_xindex = 2 * rolled_row + col_idx
    result = tl.load(in_ptr + rolled_xindex)
    # zero out the first row
    result = tl.where(row_idx == 0, 0.0, result)
    tl.store(out_ptr + xindex, result)


def triton_roll_and_zero_first_row(x):
    assert x.size(1) == 2
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']), )
    triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=256)
    return y


def roll_and_zero_first_row(x):
    x = torch.roll(x, -1, 0)
    x[0].fill_(0.0)
    return x


if __name__ == "__main__":
    inp = torch.rand(1024, 2, device="cuda")
    out_eager = roll_and_zero_first_row(inp)
    out_triton = triton_roll_and_zero_first_row(inp)
    print("eager", out_eager)
    print("triton", out_triton)
    assert torch.all(out_eager == out_triton)
    print("PASSED")
In this form the assert does not pass on my system. It passes if I either set XBLOCK=128 or if I remove tl.constexpr from NUM_ROWS. This strikes me as odd. Could you please help me understand this behaviour?
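For what it's worth, the index arithmetic itself looks correct. Here is a minimal CPU-only sketch (not from the report above, just a sanity check) that replays the kernel's formulas in plain Python on a toy 4x2 input and compares against the eager version:

import torch

# Replay the kernel's index arithmetic in plain Python on a 4x2 toy input
# and compare against the eager reference.
x = torch.arange(8, dtype=torch.float32).reshape(4, 2)
NUM_ROWS = x.size(0)

flat_in = x.reshape(-1)
flat_out = torch.empty_like(flat_in)
for xindex in range(flat_in.numel()):
    row_idx, col_idx = xindex // 2, xindex % 2
    rolled_row = (row_idx + 1) % NUM_ROWS       # same formula as the kernel
    rolled_xindex = 2 * rolled_row + col_idx
    flat_out[xindex] = 0.0 if row_idx == 0 else flat_in[rolled_xindex]

expected = torch.roll(x, -1, 0)
expected[0].fill_(0.0)
assert torch.equal(flat_out.reshape(4, 2), expected)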
The blueprint for this kernel was actually produced by applying torch.compile to roll_and_zero_first_row, which is affected by this issue as well; I just renamed a few variables and reordered code to make it human-readable. If you can confirm this behavior, I will open an issue with PyTorch as well.
I am currently on triton-nightly==2.1.0.post20240108192258, torch==2.1.2 and CUDA 12.1
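To make the dependence on the launch configuration easy to see, here is a small sketch (assuming it is appended to the repro script above, since it reuses the kernel and the eager reference defined there) that re-runs the launch with the two reported XBLOCK values:

import torch
import triton

# Re-run the Triton launch with different XBLOCK values and compare against eager.
# Relies on triton_roll_and_zero_first_row_kernel and roll_and_zero_first_row
# from the repro above.
def check_xblock(x, xblock):
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['XBLOCK']),)
    triton_roll_and_zero_first_row_kernel[grid](x, y, NUM_ROWS=x.size(0), XBLOCK=xblock)
    return bool(torch.all(roll_and_zero_first_row(x) == y))

inp = torch.rand(1024, 2, device="cuda")
for xblock in (128, 256):
    print(f"XBLOCK={xblock}:", "PASSED" if check_xblock(inp, xblock) else "FAILED")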
FWIW, I tried stepping through the kernel in interpreter mode and noticed that the repro above passes when I use TRITON_INTERPRET=1 (but fails without interpreter mode).
I am able to repro this on the main branch. Comparing XBLOCK=128 and XBLOCK=256, the ttir/ttgir look reasonable; the only difference in the ttgir is sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4] vs sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4].
The differences in llir also look reasonable. One difference in ptx is that for XBLOCK=128, there is extra logic on ctaid.x: bfe.u32 %r8, %r1, 24, 1;
It is still not clear to me why XBLOCK=256 doesn't work.
Attaching the ptx files: b256-ptx.txt, b128-ptx.txt
If we "export DISABLE_LLVM_OPT=1", the test case will pass. So it is related to llvm optimizations.
If we disable InstCombine in optimize_module, it also passes, but I haven't figured out why yet. It is not obvious to me from the differences in ptx (disable-opt against with-opt). Attached: disable-opt-llir.txt, disable-opt-ptx.txt, with-opt-llir.txt, with-opt-ptx.txt
@@ -28,8 +26,8 @@
$L__tmp0:
.loc 1 9 36
mov.u32 %r6, %tid.x;
- and.b32 %r7, %r6, 127;
- shl.b32 %r8, %r7, 1;
+ shl.b32 %r7, %r6, 1;
+ and.b32 %r8, %r7, 254;
.loc 1 8 28
// begin inline asm
mov.u32 %r1, %ctaid.x;
@@ -60,19 +58,14 @@
mov.u32 %r3, 0x0;
@%p1 ld.global.v2.b32 { %r2, %r3 }, [ %rd1 + 0 ];
// end inline asm
- mov.b32 %f1, %r2;
- mov.b32 %f2, %r3;
.loc 1 15 33
- setp.eq.s32 %p3, %r11, 0;
- .loc 1 15 41
- selp.f32 %f3, 0f00000000, %f1, %p3;
- selp.f32 %f4, 0f00000000, %f2, %p3;
+ setp.eq.s32 %p3, %r10, 0;
.loc 1 16 23
mul.wide.s32 %rd6, %r10, 4;
add.s64 %rd2, %rd4, %rd6;
.loc 1 16 31
- mov.b32 %r4, %f3;
- mov.b32 %r5, %f4;
+ selp.b32 %r4, 0, %r2, %p3;
+ selp.b32 %r5, 0, %r3, %p3;
// begin inline asm
@%p1 st.global.v2.b32 [ %rd2 + 0 ], { %r4, %r5 };
// end inline asm
It looks like the predication is different in the two cases.
Disable opt:
mov.u32 %r1, %ctaid.x;
shl.b32 %r9, %r1, 8;
or.b32 %r10, %r9, %r8;
shr.s32 %r11, %r10, 1;
setp.eq.s32 %p3, %r11, 0;
With opt:
mov.u32 %r1, %ctaid.x;
shl.b32 %r9, %r1, 8;
or.b32 %r10, %r9, %r8;
setp.eq.s32 %p3, %r10, 0;
Yes, but the problem is that r10 is supposed to be a multiple of 2, so checking it against 0 vs. checking (r10 >> 1) against 0 should be the same?
I see, it could make a difference if r10 has value of 1?
If the ptx looks correct in both cases, then this sounds like a ptxas bug. (Have we tried the latest ptxas?)
That is possible. I can try it and see if it makes a difference. Mine is currently cuda-12.1.
> I see, it could make a difference if r10 has value of 1?

Yes, if r10 can be 1. But:
mov.u32 %r6, %tid.x;
shl.b32 %r7, %r6, 1;
and.b32 %r8, %r7, 254;
mov.u32 %r1, %ctaid.x;
shl.b32 %r9, %r1, 8;
or.b32 %r10, %r9, %r8;
%r8 and %r9 should both be multiples of 2, so %r10 = %r9 | %r8 cannot be 1.
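That argument can be checked with a few lines of plain Python (loop bounds taken from the repro: 8 programs of 128 threads for the 1024x2 input at XBLOCK=256):

# %r8 = (tid.x << 1) & 254 and %r9 = ctaid.x << 8 are both even, so
# %r10 = %r9 | %r8 can never be 1; checking %r10 == 0 is then equivalent
# to checking (%r10 >> 1) == 0.
for ctaid in range(8):          # 8 programs: cdiv(1024 * 2, 256)
    for tid in range(128):      # 4 warps of 32 threads
        r8 = (tid << 1) & 254
        r9 = ctaid << 8
        r10 = r9 | r8
        assert r10 % 2 == 0
        assert (r10 == 0) == ((r10 >> 1) == 0)
print("the two predicates agree for every ctaid/tid")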
You can also disable optimizations in ptxas with --opt-level 0.
Disabling optimization for ptxas also fixed the problem. Attached: def-sass.txt, no-ptxas-opt-sass.txt
Patch to enable debugging: https://github.com/openai/triton/pull/2995
Thanks for giving it a shot. So the problem went away with LLVM opt on but PTX opt off?
A quick summary:
1> default config (with LLVM optimizations, with ptxas optimizations): BLOCKSIZE of 128 works; BLOCKSIZE of 256 fails
2> LLVM optimizations disabled via DISABLE_LLVM_OPT, with ptxas optimizations: BLOCKSIZE of 256 works
2a> with LLVM optimizations O0, InstCombinePass disabled here https://github.com/openai/triton/blob/e6e5d5468e92ed3af3e40babdd55c3da506ab01f/python/src/llvm.cc#L190: BLOCKSIZE of 256 works
3> LLVM optimizations enabled, without ptxas optimizations: BLOCKSIZE of 256 works
I haven't looked at the differences in sass for item 3 yet.
Attempted to debug this with cuda-gdb, but it turned out that the issue is likely in ptxas: the update of the predicate seems to be gone in the sass (i.e. P0 is used but never set).
IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28]
S2R R0, SR_TID.X
IMAD.MOV.U32 R6, RZ, RZ, 0x4
ULDC.64 UR4, c[0x0][0x118]
S2R R5, SR_CTAID.X
IMAD.SHL.U32 R0, R0, 0x2, RZ
LOP3.LUT R0, R0, 0xfe, RZ, 0xc0, !PT
PRMT R4, R0, 0x6540, R5
LEA.HI.SX32 R2, R4, 0x1, 0x1f
SHF.R.S32.HI R3, RZ, 0x1f, R2
LEA.HI R3, R3, R2, RZ, 0xa
LOP3.LUT R3, R3, 0x7ffffc00, RZ, 0xc0, !PT
IMAD.IADD R3, R2, 0x1, -R3
IMAD.SHL.U32 R3, R3, 0x2, RZ
IMAD.WIDE R2, R3, R6, c[0x0][0x160]
LDG.E.64 R2, [R2.64]
PRMT RZ, R0, 0x6540, R5
IMAD.WIDE R4, R4, R6, c[0x0][0x168]
SEL R6, R2, RZ, P0
SEL R7, R3, RZ, P0
STG.E.64 [R4.64], R6
Setting of predicate in ptx:
setp.eq.s32 %p3, %r10, 0;
.loc 1 16 23
mul.wide.s32 %rd6, %r10, 4;
add.s64 %rd2, %rd4, %rd6;
.loc 1 16 31
selp.b32 %r4, 0, %r2, %p3;
selp.b32 %r5, 0, %r3, %p3;
Interesting findings! Ptxas bugs are always very hard to track :(
Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)
You can also run ptxas with different versions and override the cubin generated if it is simpler
This will be super useful if we can swap the cubin and run it. How do we do it?
> Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)
I haven't. Let me ask around how to get the latest.
> This will be super useful if we can swap the cubin and run it. How do we do it?
It's a bit of an experimental debug feature, but here: https://github.com/openai/triton/blob/main/python/triton/compiler/compiler.py#L218
The flow I usually use:
- export TRITON_KERNEL_DUMP=1 and clear the cache
- run the test: python ....
- find the kernel you are interested in under ~/.triton/dump/kernel_hash/
- copy the new cubin into ~/.triton/override/ with the same exact path as above: ~/.triton/override/kernel_hash/kernel_name.cubin
- clear the cache and run TRITON_KERNEL_OVERRIDE=1 python....
- check the console for the message "Overriding kernel with file kernel_name.cubin" to make sure it happened
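For reference, the copy step of that flow as a small Python sketch (kernel_hash and kernel_name are placeholders for the actual directory and kernel under ~/.triton/dump):

import shutil
from pathlib import Path

# Mirror a (possibly hand-patched) cubin from the dump tree into the override tree,
# keeping the kernel_hash/kernel_name.cubin layout described above.
src = Path.home() / ".triton" / "dump" / "kernel_hash" / "kernel_name.cubin"
dst = Path.home() / ".triton" / "override" / "kernel_hash" / "kernel_name.cubin"
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(src, dst)
print("override placed at", dst)

Then clear the cache and run with TRITON_KERNEL_OVERRIDE=1 as described above.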
> Have you tried using the latest ptxas as suggested by @jlebar? (we haven't upgraded to the latest one yet)
> I haven't. Let me ask around how to get the latest.
You can see an example here: https://github.com/ThomasRaoux/triton/commit/edb74eb5bda10c7019f3717acb5d0c97eb2a411d
Tried with the cuda_12.3.r12.3 ptxas; it generates the same code. I am not familiar with SASS, so maybe it is possible that P0 is set implicitly by some instruction? But the incorrect outputs are all 0.0, which looks like the predicate is always choosing 0.0.
Since it seems very possible that there's a ptxas bug, we could prepare a reproducer and send it to the NVIDIA compiler folks. They usually respond quickly in my experience.
@Jokeren Yes, here is the repro and a README; let me know how to send it over to the NVIDIA compiler folks. Attached: repro-ptx.txt, README.txt
Filed as https://developer.nvidia.com/bugs/4474599
Reply from NVIDIA: This is a ptxas optimization bug that we've already fixed internally. It will be available in an update release of CUDA 12.4 soon. The optimization was incorrectly transforming a LOP3 (which supports predicate output) into a PRMT (which doesn't support predicate output).
Thanks to everyone involved for figuring this out 🙌
CUDA 12.4 is out. Has anyone tested it?
Still running into the same issues with CUDA 12.4, torch=2.5.0.dev20240709+cu124, triton=2.3.1