No `other` branch in tl.load leads to invalid results, even though all other values are masked out
Repro below. The generated PTX looks valid in both cases, with the only difference being the movs predicated with @!pxx, as expected. This happens with fp16; float32 is fine. I'm deliberately setting `other` to nan to make sure it's all masked out, and the results are always non-nan. If the `other` kwarg is not specified, the first 2 columns of the result are random. I've tried trimming the loads down further, but it looks like 3 loads is the minimum needed to trigger this bug.
import torch
aten = torch.ops.aten
import triton
import triton.language as tl
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x1 = (xindex // 8) % 8
    x0 = xindex % 8
    x4 = xindex
    tmp0 = (-1) + x1
    tmp1 = 0
    tmp2 = tmp0 >= tmp1
    tmp3 = 8
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = (-1) + x0
    tmp7 = tmp6 >= tmp1
    tmp8 = tmp6 < tmp3
    tmp9 = tmp7 & tmp8
    tmp13 = x0
    tmp14 = tmp13 >= tmp1
    tmp15 = tmp13 < tmp3
    tmp16 = tmp14 & tmp15
    tmp21 = 1 + x0
    tmp22 = tmp21 >= tmp1
    tmp23 = tmp21 < tmp3
    tmp24 = tmp22 & tmp23
    tmp25 = tmp5 & tmp24
    tmp26 = tl.load(in_ptr0 + (-7) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp25).to(tl.float32)
    tmp27 = tl.where(tmp25, tmp26, 0.0)
    tmp29 = x1
    tmp30 = tmp29 >= tmp1
    tmp31 = tmp29 < tmp3
    tmp32 = tmp30 & tmp31
    tmp33 = tmp32 & tmp9
    # no `other` kwarg leads to invalid results
    # tmp34 = tl.load(in_ptr0 + (-1) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp33).to(tl.float32)
    tmp34 = tl.load(in_ptr0 + (-1) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp33, other=float('nan')).to(tl.float32)
    tmp35 = tl.where(tmp33, tmp34, 0.0)
    tmp36 = tmp35 + tmp27
    tmp37 = tmp32 & tmp16
    # no `other` kwarg leads to invalid results
    # tmp38 = tl.load(in_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp37).to(tl.float32)
    tmp38 = tl.load(in_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp37, other=float('nan')).to(tl.float32)
    tmp39 = tl.where(tmp37, tmp38, 0.0)
    tmp40 = tmp39 + tmp36
    tl.store(out_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), tmp40, xmask)
def call(arg0_1, arg1_1):
    arg0_1_size = arg0_1.size()
    s0 = arg0_1_size[0]
    buf0 = torch.empty_strided((s0, 2048, 8, 8), (131072, 64, 8, 1), device='cuda', dtype=torch.float16)
    kernel0_xnumel = 131072*s0
    pgm = kernel[(1,)](arg0_1, buf0, kernel0_xnumel, 1024)
    ptx = pgm.asm['ptx']
    with open("avgpool.ptx", "w") as f:
        for l in ptx:
            f.write(l)
    print(buf0[0,0])
    return (buf0, )

if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance
    arg0_1 = rand_strided((2, 2048, 8, 8), (131072, 64, 8, 1), device='cuda', dtype=torch.float16).fill_(1.)
    arg1_1 = rand_strided((192, 2048, 1, 1), (2048, 1, 1, 1), device='cuda', dtype=torch.float16).fill_(1.)
    call(arg0_1, arg1_1)
@ngimel Can you try the master branch to see if this problem still exists?
It seems to me a bit different from #745, since you said the problem appears when other is not used. Probably have to take a further look.
Yes, the problem still exists on master. #745 is different, I hit #745 because I'm forced to use other even though I don't think I need to (other values should be masked anyway).
Got you, I'll take a look later.
Hi @ngimel, I have a question about the following two lines:
tmp26 = tl.load(in_ptr0 + (-7) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp25).to(tl.float32)
tmp27 = tl.where(tmp25, tmp26, 0.0)
Is there any false in xmask & tmp25? If so, my understanding is that the values loaded at the false locations are undefined; in other words, they are whatever random values happen to be in the corresponding registers.
Specifying other = 0 should also yield the expected results.
Yeah, it's possible that undefined values are loaded, but they are masked out in the following lines, so it shouldn't matter whether an `other` is specified or not. That's what I demonstrate by specifying other=float('nan'): the results are never nan.
Specifying any `other` yields the expected results, again because the particular value doesn't matter (it's masked out), but for the same reason not specifying `other` should work too.
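To make the argument concrete, here is a minimal sketch of the pattern under discussion (a hypothetical standalone kernel, not the repro above): the same mask guards both the load and the tl.where, so whatever ends up in the masked-out lanes, defined or not, should never be observable in the output.

import triton
import triton.language as tl

@triton.jit
def masked_pattern(in_ptr, out_ptr, n, XBLOCK: tl.constexpr):
    idx = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    mask = idx < n
    # load guarded by `mask`; with no `other`, lanes where mask is False hold unspecified values
    val = tl.load(in_ptr + idx, mask)
    # the same `mask` selects 0.0 for exactly those lanes, so the unspecified
    # values should never reach the store below
    out = tl.where(mask, val, 0.0)
    tl.store(out_ptr + idx, out, mask)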
[Sorry got my comment accidentally deleted]
In the triton programming model, I think the mask doesn't propagate across multiple instructions. @ptillet can you please confirm?
So for this instruction: tl.where(tmp33, tmp34, 0.0), if tmp33[0] == 1, it will select tmp34[0] which could be an undefined value.
xmask is always True in this case (I launch just 1 program_id, so the maximum xindex is 1023, while xnumel is much larger). The same mask, tmp33, is used both for the load and for the tl.where, so tmp34 is never used where it's undefined (think of it this way: when I set other=float('nan'), if that value were ever used later it would corrupt the output, but it doesn't).
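As a quick sanity check of that claim (plain Python, just redoing the arithmetic from the repro's launch parameters; not part of the original report):

XBLOCK = 1024
s0 = 2                                    # batch size used in the repro
xnumel = 131072 * s0                      # kernel0_xnumel
max_xindex = 0 * XBLOCK + (XBLOCK - 1)    # only one program is launched, so program_id(0) == 0
assert max_xindex < xnumel                # hence xmask = xindex < xnumel is all-True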
That's interesting, the problem went away by removing xmask.
tmp26 = tl.load(in_ptr0 + (-7) + x4 + tl.zeros([XBLOCK], tl.int32), tmp25).to(tl.float32)
tmp34 = tl.load(in_ptr0 + (-1) + x4 + tl.zeros([XBLOCK], tl.int32), tmp33).to(tl.float32)
tmp38 = tl.load(in_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), tmp37).to(tl.float32)
Yeah I noticed that removing xmask changes results (I don't remember if it always fixed the problem), but that was very surprising too, as xmask is always true here.
Hmmm 🤔 I think there are multiple options:
- xmask isn't always true for some reason
- removing xmask changes vectorization behavior. This kernel is fairly simple so I wouldn't really expect bugs there, but seeing the PTX wouldn't hurt :p
- maybe a weird issue with &
We found there's an issue with the generated PTX code: register %rs9 is not used after it's defined. However, the problem doesn't exist in the generated LLVM IR, so we suspected this might be an issue in the PTX codegen logic in LLVM. Since Triton uses a relatively old LLVM release (11.0), I tried updating to more recent releases, but llvm-12.0 and llvm-13.0 didn't solve the problem. I'm not able to make llvm >= 14.0 work because there are many API changes incompatible with the current Triton backend. I think we could come back to this issue after the MLIR integration is done, since the triton-mlir branch uses a much newer LLVM.
@%p9 ld.global.b16 {%rs9}, [ %rd12 + 0];
@%p10 ld.global.b16 {%rs10}, [ %rd12 + 2];
mov.b16 %h9, %rs10;
@%p11 ld.global.b16 {%rs11}, [ %rd12 + 4];
mov.b16 %h10, %rs11;
@%p12 ld.global.b16 {%rs12}, [ %rd12 + 6];
mov.b16 %h11, %rs12;
@%p13 ld.global.b16 {%rs13}, [ %rd12 + 8];
mov.b16 %h12, %rs13;
@%p14 ld.global.b16 {%rs14}, [ %rd12 + 10];
mov.b16 %h13, %rs14;
@%p15 ld.global.b16 {%rs15}, [ %rd12 + 12];
mov.b16 %h14, %rs15;
@%p16 ld.global.b16 {%rs16}, [ %rd12 + 14];
mov.b16 %h15, %rs16;
A new example: with and without other=0.0 in the tl.load, it yields different results.
from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
triton__5 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor
@reduction(size_hints=[8192, 256],
reduction_hint=ReductionHint.INNER,
filename=__file__,
meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton__5(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp4 = (tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")).to(tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (197*x0)), rmask & xmask).to(tl.float32)
        tmp1 = 0.125
        tmp2 = tmp0 * tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp5 = _tmp4 < tmp3
        _tmp4 = tl.where(rmask & xmask & (_tmp4 < tmp3), tmp3, _tmp4)
        tl.store(out_ptr2 + (r1 + (197*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)
''')
async_compile.wait(globals())
del async_compile
torch.manual_seed(123)
stream0 = get_cuda_stream(0)
s0 = 4
buf11 = torch.randn((6*s0, 197, 197), device='cuda', dtype=torch.float16)
buf14 = torch.randn((s0, 6, 197, 197), device='cuda', dtype=torch.float16)
triton__5_xnumel = 1182*s0
triton__5.run(buf11, buf14, triton__5_xnumel, 197, grid=grid(triton__5_xnumel), stream=stream0)
print(buf14.mean())
And again, removing (_tmp4 < tmp3) from the tl.where mask changes the results.
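For reference, the two variants being compared there (using the names from the kernel above; the second line is just my reading of "removing (_tmp4 < tmp3) from the mask"):

# as written in the repro: the running-max comparison is folded into the tl.where mask
_tmp4 = tl.where(rmask & xmask & (_tmp4 < tmp3), tmp3, _tmp4)

# with the comparison dropped from the mask; per the comment above, this variant gives different results
_tmp4 = tl.where(rmask & xmask, tmp3, _tmp4)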