
No `other` branch in tl.load leads to invalid results, even though all other values are masked out

Open: ngimel opened this issue 3 years ago (13 comments)

Repro below. The generated PTX looks valid in both cases; as expected, the only difference is in the `mov`s predicated with `@!pxx`. This happens with fp16; float32 is fine. I deliberately set `other` to nan to make sure those values are all masked out, and the results are always non-nan. If the `other` kwarg is not specified, the first 2 columns of the result are random. I've tried trimming down the loads further, but 3 loads seem to be the minimum needed to trigger this bug.

import torch

aten = torch.ops.aten

import triton
import triton.language as tl


@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x1 = (xindex // 8) % 8
    x0 = xindex % 8
    x4 = xindex
    tmp0 = (-1) + x1
    tmp1 = 0
    tmp2 = tmp0 >= tmp1
    tmp3 = 8
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = (-1) + x0
    tmp7 = tmp6 >= tmp1
    tmp8 = tmp6 < tmp3
    tmp9 = tmp7 & tmp8
    tmp13 = x0
    tmp14 = tmp13 >= tmp1
    tmp15 = tmp13 < tmp3
    tmp16 = tmp14 & tmp15
    tmp21 = 1 + x0
    tmp22 = tmp21 >= tmp1
    tmp23 = tmp21 < tmp3
    tmp24 = tmp22 & tmp23
    tmp25 = tmp5 & tmp24
    tmp26 = tl.load(in_ptr0 + (-7) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp25).to(tl.float32)
    tmp27 = tl.where(tmp25, tmp26, 0.0)
    tmp29 = x1
    tmp30 = tmp29 >= tmp1
    tmp31 = tmp29 < tmp3
    tmp32 = tmp30 & tmp31
    tmp33 = tmp32 & tmp9
    # no `other` kwarg leads to invalid results
    # tmp34 = tl.load(in_ptr0 + (-1) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp33).to(tl.float32)
    tmp34 = tl.load(in_ptr0 + (-1) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp33, other=float('nan')).to(tl.float32)
    tmp35 = tl.where(tmp33, tmp34, 0.0)
    tmp36 = tmp35 + tmp27
    tmp37 = tmp32 & tmp16
    # no `other` kwarg leads to invalid results
    # tmp38 = tl.load(in_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp37).to(tl.float32)
    tmp38 = tl.load(in_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp37, other=float('nan')).to(tl.float32)
    tmp39 = tl.where(tmp37, tmp38, 0.0)
    tmp40 = tmp39 + tmp36
    tl.store(out_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), tmp40, xmask)




def call(arg0_1, arg1_1):
    arg0_1_size = arg0_1.size()
    s0 = arg0_1_size[0]
    buf0 = torch.empty_strided((s0, 2048, 8, 8), (131072, 64, 8, 1), device='cuda', dtype=torch.float16)
    kernel0_xnumel = 131072*s0
    # launch a single program of XBLOCK=1024 elements and dump its PTX
    pgm = kernel[(1,)](arg0_1, buf0, kernel0_xnumel, 1024)
    ptx = pgm.asm['ptx']
    with open("avgpool.ptx", "w") as f:
        f.write(ptx)
    print(buf0[0,0])
    return (buf0, )


if __name__ == "__main__":
    # Both inputs are filled with ones, so their initial contents do not matter.
    arg0_1 = torch.empty_strided((2, 2048, 8, 8), (131072, 64, 8, 1), device='cuda', dtype=torch.float16).fill_(1.)
    arg1_1 = torch.empty_strided((192, 2048, 1, 1), (2048, 1, 1, 1), device='cuda', dtype=torch.float16).fill_(1.)
    call(arg0_1, arg1_1)
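To check `buf0[0,0]` by eye, here is a NumPy sketch of what the kernel should produce for one 8x8 plane, derived only from its index arithmetic. The load at offset -7 reads row x1-1, column x0+1; the load at offset -1 reads column x0-1; the load at offset 0 reads the element itself. Out-of-bounds neighbours contribute 0.0, as the `tl.where` calls intend. `reference_plane` is a hypothetical helper, and it assumes each plane is contiguous with strides (8, 1), as in the repro.

```python
import numpy as np


def reference_plane(plane: np.ndarray) -> np.ndarray:
    """Expected output of `kernel` for one contiguous 8x8 plane (hypothetical helper)."""
    rows, cols = plane.shape
    out = np.zeros((rows, cols), dtype=np.float32)
    for r in range(rows):
        for c in range(cols):
            total = 0.0
            # offset -7: row above, column to the right (guarded by tmp25)
            if 0 <= r - 1 < rows and 0 <= c + 1 < cols:
                total += plane[r - 1, c + 1]
            # offset -1: same row, column to the left (guarded by tmp33)
            if 0 <= c - 1 < cols:
                total += plane[r, c - 1]
            # offset 0: the element itself (tmp37 is always true here)
            total += plane[r, c]
            out[r, c] = total
    return out


expected = reference_plane(np.ones((8, 8), dtype=np.float32))
print(expected)
```

With the all-ones input from the repro, the first row should be 1 followed by seven 2s, and every other row should be 2, 3, 3, 3, 3, 3, 3, 2.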

ngimel, Oct 04 '22

@ngimel Can you try the master branch to see if this problem still exists?

Jokeren, Oct 12 '22

This seems a bit different from #745, since you said the problem appears when `other` is not used. I'll probably need to take a closer look.

Jokeren, Oct 12 '22

Yes, the problem still exists on master. #745 is different: I hit #745 because I'm forced to use `other` even though I don't think I need to (the `other` values should be masked out anyway).

ngimel, Oct 13 '22

Got it, I'll take a look later.

Jokeren, Oct 13 '22

Hi @ngimel, I have a question about the following two lines:

tmp26 = tl.load(in_ptr0 + (-7) + x4 + tl.zeros([XBLOCK], tl.int32), xmask & tmp25).to(tl.float32)
tmp27 = tl.where(tmp25, tmp26, 0.0)

Are any elements of xmask & tmp25 false? If so, my understanding is that the values loaded at those locations are undefined, i.e., whatever random values happen to be in the corresponding registers.

Jokeren, Oct 17 '22

Specifying other = 0 should also yield the expected results.

Jokeren, Oct 17 '22

Yes, undefined values may be loaded, but they are masked out in the following lines, so it shouldn't matter whether `other` is specified. That's what I demonstrate by specifying other=float('nan'): the results aren't nan. Specifying any value for `other` yields the expected results, again because the particular value doesn't matter since it's masked out; for the same reason, not specifying `other` should work too.
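The argument can be made concrete with a small NumPy model of the load-then-`tl.where` pattern. `emulated_masked_load` is a hypothetical stand-in, not Triton's implementation: it assumes masked-off lanes never touch memory and hold `other` if given, otherwise arbitrary garbage. When `tl.where` uses the same mask as the load, the result is identical for `other=None`, `nan`, and `0.0`, which is why an output that changes with `other` points toward a code-generation problem rather than a misuse of masks.

```python
import numpy as np


def emulated_masked_load(buf, offsets, mask, other=None, rng=None):
    """Hypothetical model of tl.load(ptr + offsets, mask, other=...).

    Masked-off lanes get `other` if given, otherwise random garbage,
    mimicking undefined register contents.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    if other is None:
        result = rng.uniform(-1e4, 1e4, size=offsets.shape).astype(np.float32)
    else:
        result = np.full(offsets.shape, other, dtype=np.float32)
    # Only lanes with mask=True read memory.
    result[mask] = buf[offsets[mask]]
    return result


buf = np.arange(16, dtype=np.float32)
offsets = np.arange(16) - 1      # like `x4 - 1`: lane 0 would read out of bounds
mask = offsets >= 0              # the load mask

results = []
for other in (None, float("nan"), 0.0):
    loaded = emulated_masked_load(buf, offsets, mask, other)
    # Same mask in the select: masked-off lanes are replaced before use.
    results.append(np.where(mask, loaded, 0.0))

print(results[0])
```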

ngimel, Oct 17 '22

[Sorry, my comment was accidentally deleted.]

In the Triton programming model, I think the mask doesn't propagate across instructions. @ptillet, can you please confirm?

So for tl.where(tmp33, tmp34, 0.0), if tmp33[0] == 1, it will select tmp34[0], which could be an undefined value.

Jokeren, Oct 17 '22

xmask is always True in this case (I launch just one program, so the maximum xindex is 1023, while xnumel is much larger). So the same mask, tmp33, is used both for the load and for tl.where, and tmp34 is never used where it's undefined. Think of it this way: when I set other=float('nan'), if that value were ever used later it would corrupt the output, but it doesn't.
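The disagreement comes down to whether the `tl.where` mask is a subset of the load mask. The sketch below recomputes `xmask` and `tmp33` with NumPy for a given launch and counts lanes where `tl.where(tmp33, tmp34, 0.0)` would select a value the load skipped. `leaking_lanes` is a hypothetical helper; it assumes the index arithmetic from the repro and `XBLOCK=1024`.

```python
import numpy as np


def leaking_lanes(xnumel, xblock=1024, num_programs=1):
    """Count lanes where tl.where(tmp33, ...) is true but the load mask
    (xmask & tmp33) was false, i.e. lanes that would expose undefined values."""
    xindex = np.arange(num_programs * xblock)
    xmask = xindex < xnumel
    x1 = (xindex // 8) % 8
    x0 = xindex % 8
    tmp33 = (x1 >= 0) & (x1 < 8) & (x0 - 1 >= 0) & (x0 - 1 < 8)
    load_mask = xmask & tmp33
    return int(np.count_nonzero(tmp33 & ~load_mask))


print(leaking_lanes(131072 * 2))  # the repro: s0 = 2, one program
print(leaking_lanes(500))         # hypothetical small xnumel
```

For the repro's launch the count is 0, consistent with xmask being all true; a small xnumel such as 500 would make undefined values reachable.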

ngimel, Oct 17 '22

Interesting: the problem goes away when xmask is removed from the load masks:

tmp26 = tl.load(in_ptr0 + (-7) + x4 + tl.zeros([XBLOCK], tl.int32), tmp25).to(tl.float32)
tmp34 = tl.load(in_ptr0 + (-1) + x4 + tl.zeros([XBLOCK], tl.int32), tmp33).to(tl.float32)
tmp38 = tl.load(in_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), tmp37).to(tl.float32)

Jokeren, Oct 17 '22

Yeah, I noticed that removing xmask changes the results (I don't remember whether it always fixed the problem), but that was very surprising too, since xmask is always true here.

ngimel, Oct 17 '22

Hmmm 🤔 I think there are multiple options:

  • xmask isn't always true for some reason
  • removing xmask changes vectorization behavior. This kernel is fairly simple so I wouldn't really expect bugs there, but seeing the PTX wouldn't hurt :p
  • maybe a weird issue with &

ptillet, Oct 17 '22

We found an issue in the generated PTX: register %rs9 is loaded but never used afterwards, whereas each of the other loaded registers is copied into a %h register by a mov.b16. The problem does not exist in the generated LLVM IR, so we suspect an issue in LLVM's PTX code generation. Since Triton uses a relatively old LLVM release (11.0), I tried updating to newer releases, but the problem persisted with LLVM 12.0 and 13.0. I couldn't get LLVM >= 14.0 to work because of many API changes that are incompatible with the current Triton backend. I think we could revisit this issue after the MLIR integration is done, since the triton-mlir branch uses a newer LLVM.

@%p9 ld.global.b16 {%rs9}, [ %rd12 + 0];
@%p10 ld.global.b16 {%rs10}, [ %rd12 + 2];
mov.b16     %h9, %rs10;
@%p11 ld.global.b16 {%rs11}, [ %rd12 + 4];
mov.b16     %h10, %rs11;
@%p12 ld.global.b16 {%rs12}, [ %rd12 + 6];
mov.b16     %h11, %rs12;
@%p13 ld.global.b16 {%rs13}, [ %rd12 + 8];
mov.b16     %h12, %rs13;
@%p14 ld.global.b16 {%rs14}, [ %rd12 + 10];
mov.b16     %h13, %rs14;
@%p15 ld.global.b16 {%rs15}, [ %rd12 + 12];
mov.b16     %h14, %rs15;
@%p16 ld.global.b16 {%rs16}, [ %rd12 + 14];
mov.b16     %h15, %rs16;
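A rough way to spot this symptom in a larger PTX dump (for example the `avgpool.ptx` file the first repro writes) is to list `ld.global` destination registers that never appear on any other line. This is a heuristic sketch, not a PTX parser: `unused_load_destinations` is a hypothetical helper, and it assumes each load has a single destination register in braces, so vectorized loads such as `{%r1, %r2}` are ignored.

```python
import re

# "ld.global.<type> {%reg}" with exactly one destination register
LOAD_RE = re.compile(r"ld\.global\.\w+\s+\{(%\w+)\}")
REG_RE = re.compile(r"%\w+")


def unused_load_destinations(ptx: str) -> list:
    """Destination registers of single-register ld.global instructions
    that are not referenced on any other line of `ptx`."""
    lines = ptx.splitlines()
    unused = []
    for i, line in enumerate(lines):
        m = LOAD_RE.search(line)
        if not m:
            continue
        reg = m.group(1)
        # Exact register-name match, so %rs1 is not confused with %rs10.
        used_elsewhere = any(
            reg in REG_RE.findall(other) for j, other in enumerate(lines) if j != i
        )
        if not used_elsewhere:
            unused.append(reg)
    return unused


sample = """@%p9 ld.global.b16 {%rs9}, [ %rd12 + 0];
@%p10 ld.global.b16 {%rs10}, [ %rd12 + 2];
mov.b16     %h9, %rs10;
@%p11 ld.global.b16 {%rs11}, [ %rd12 + 4];
mov.b16     %h10, %rs11;"""

print(unused_load_destinations(sample))
```

A register that is later overwritten by an unrelated instruction would still count as "used", so an empty result does not prove the PTX is correct.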

Jokeren, Oct 18 '22

A new example:

With and without other=0.0, the results differ.

from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

triton__5 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_ops.autotune import reduction
from torch._inductor.utils import instance_descriptor

@reduction(size_hints=[8192, 256],
              reduction_hint=ReductionHint.INNER,
              filename=__file__,
              meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())]})
@triton.jit
def triton__5(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp4 = (tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")).to(tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (197*x0)), rmask & xmask).to(tl.float32)
        tmp1 = 0.125
        tmp2 = tmp0 * tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp5 = _tmp4 < tmp3
        _tmp4 = tl.where(rmask & xmask & (_tmp4 < tmp3), tmp3, _tmp4)
        tl.store(out_ptr2 + (r1 + (197*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)
''')

async_compile.wait(globals())
del async_compile

torch.manual_seed(123)
stream0 = get_cuda_stream(0)
s0 = 4
buf11 = torch.randn((6*s0, 197, 197), device='cuda', dtype=torch.float16)
buf14 = torch.randn((s0, 6, 197, 197), device='cuda', dtype=torch.float16)
triton__5_xnumel = 1182*s0
triton__5.run(buf11, buf14, triton__5_xnumel, 197, grid=grid(triton__5_xnumel), stream=stream0)
print(buf14.mean())
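To judge which output is correct, it helps to spell out what `triton__5` computes. Reading the loop, `_tmp4` is a per-lane running maximum of `0.125 * x` that persists across `RBLOCK`-sized chunks and is stored after every chunk, so the stored values depend on the `RBLOCK` the autotuner picks: for finite inputs with `RBLOCK >= 197` they equal `0.125 * x`, while smaller blocks write running maxima into later columns. The NumPy sketch below emulates that loop; `triton5_reference` is a hypothetical helper, it computes in float32, and it does not model the fp16 rounding of the final store.

```python
import numpy as np


def triton5_reference(inp: np.ndarray, rblock: int) -> np.ndarray:
    """Emulate the triton__5 loop on a 2-D input of shape (xnumel, rnumel)."""
    xnumel, rnumel = inp.shape
    out = np.empty((xnumel, rnumel), dtype=np.float32)
    acc = np.full((xnumel, rblock), -np.inf, dtype=np.float32)  # _tmp4
    for roffset in range(0, rnumel, rblock):
        cols = np.arange(roffset, roffset + rblock)
        rmask = cols < rnumel
        # tmp3 = 0.125 * x; lanes beyond rnumel are never selected below
        tmp3 = np.full((xnumel, rblock), -np.inf, dtype=np.float32)
        tmp3[:, rmask] = inp[:, cols[rmask]].astype(np.float32) * np.float32(0.125)
        # _tmp4 = tl.where(rmask & (_tmp4 < tmp3), tmp3, _tmp4)
        acc = np.where(rmask[None, :] & (acc < tmp3), tmp3, acc)
        # the store happens inside the loop, after every chunk
        out[:, cols[rmask]] = acc[:, rmask]
    return out


x = np.array([[4.0, 1.0, 2.0]], dtype=np.float32)
print(triton5_reference(x, 4))
print(triton5_reference(x, 2))
```

For the repro, `buf11.view(-1, 197).float().cpu().numpy()` gives the (xnumel, rnumel) = (4728, 197) input when s0 = 4.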

Jokeren, Feb 13 '23

And again, removing (_tmp4 < tmp3) from the tl.where mask changes the results.

ngimel, Feb 13 '23