triton icon indicating copy to clipboard operation
triton copied to clipboard

Wrong results where one of the args is assigned to constant inside the kernel

Open ngimel opened this issue 1 year ago • 7 comments

This might be related to #714. Repro below (comments inline; unfortunately it requires torchdynamo). tl;dr: if the kernel contains `xnumel = <const>`, where `xnumel` is also a kernel argument and the constant equals the value of `xnumel` passed to the kernel (so the assignment should be a no-op or, even if it is used for optimization, should not change results), the kernel produces wrong results. Note that this is with the new runtime; with the old runtime, both versions of the kernel produce wrong results.
Because the repro requires dynamo, I'm happy to provide the generated PTX or any additional info if needed. Alternatively, minor changes that disable pre-compilation remove the dynamo dependency and still reproduce the wrong results.

from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import AsyncCompile

# Shorthand for the ATen operator namespace (not referenced again in the visible script).
aten = torch.ops.aten
# Compiler handle: both Triton kernels below are built from source strings through
# its `.triton(...)` method and finalised with `.wait(globals())`.
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


# Baseline kernel: `xnumel` is used exactly as passed in by the caller.
# Each program covers XBLOCK consecutive indices `xindex`, masked by
# `xindex < xnumel`. For every live index it loads an int64 offset from in_ptr0
# and a float32 value from in_ptr1, then stores the value at
# out_ptr0 + (xindex % 1000) + 1000 * offset + 196000 * (xindex // 1000).
# The source below is a runtime string handed to the compiler, so it is kept verbatim.
kernel1 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor

@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1000
    x1 = (xindex // 1000)
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr1 + x2, xmask)
    tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)
''')

# Variant under test: the kernel source is the same as kernel1's except for its
# first body statement, `xnumel = 2000`, which rebinds the `xnumel` argument to a
# constant. The caller passes 2000 as well, so the mask `xindex < xnumel` is
# expected to be unchanged; the issue reports that the stored results differ anyway.
# The source below is a runtime string handed to the compiler, so it is kept verbatim.
kernel2 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor

@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1000
    x1 = (xindex // 1000)
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr1 + x2, xmask)
    tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)
''')
# Pass the module namespace to the compiler. Presumably this blocks until both
# kernels are built and rebinds `kernel1`/`kernel2` to the compiled objects;
# TODO confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The compiler handle is not used after this point.
del async_compile

def call(arg4_1, arg7_1):
    """Launch kernel1 and kernel2 on the same inputs and print a summary of each result.

    Each kernel writes into its own freshly zeroed float32 CUDA buffer of shape
    (2, 196, 1000) with strides (196000, 1000, 1), and is launched with the same
    element count, ``1000 * s1`` (= 2000). After each launch the function prints
    the per-row maximum of the first ten rows of ``buf0[0]`` and the total sum
    of the buffer, so the two kernels can be compared by eye.

    Args:
        arg4_1: Passed to the kernels as ``in_ptr0`` (declared ``*i64`` in the
            kernel signature).
        arg7_1: Passed to the kernels as ``in_ptr1`` (declared ``*fp32``).

    Returns:
        A one-element tuple holding the buffer written by kernel2.
    """
    s1 = 2
    stream0 = get_cuda_stream(0)

    def _zeroed_output():
        # Every launch gets its own zero-filled buffer so untouched elements
        # read as 0 and the two kernels' results cannot mix.
        return empty_strided(
            (2, 196, 1000), (196000, 1000, 1), device='cuda', dtype=torch.float32
        ).fill_(0)

    buf0 = _zeroed_output()
    kernel1_xnumel = 1000*s1
    kernel1.run(arg4_1, arg7_1, buf0, kernel1_xnumel, grid=grid(kernel1_xnumel), stream=stream0)
    # Alternative launch kept for reference:
    #kernel1[grid(kernel1_xnumel)](arg4_1, arg7_1, buf0, kernel1_xnumel, 1024)
    print(buf0[0].amax(-1)[:10], buf0.sum()) #no xnumel=2000 in the kernel, correct answer

    buf0 = _zeroed_output()
    kernel2_xnumel = 1000*s1
    # Pass kernel2's own element count (same value as kernel1's) so the argument
    # and the grid are derived from the same variable.
    kernel2.run(arg4_1, arg7_1, buf0, kernel2_xnumel, grid=grid(kernel2_xnumel), stream=stream0)
    # Alternative launch kept for reference:
    #kernel2[grid(kernel2_xnumel)](arg4_1, arg7_1, buf0, kernel2_xnumel, 1024)

    print(buf0[0].amax(-1)[:10], buf0.sum()) #xnumel=2000 in the kernel, wrong answer
    return (buf0, )


if __name__ == "__main__":
    # Seeding only matters if the random-index alternative below is enabled;
    # the default inputs are fully deterministic.
    torch.manual_seed(12345)
    # Indices in [0, 196), shaped (2, 1, 1000) to match the scatter source below.
    # Random alternative:
    #   torch.randint(196, (2, 1, 1000), device="cuda", dtype=torch.int64)
    arg4_1 = torch.arange(2000, device="cuda", dtype=torch.int64).reshape(2, 1, 1000) % 196
    # Values 1..2000, so every source element is distinct and non-zero.
    arg7_1 = torch.arange(1, 2001, device="cuda", dtype=torch.float32).reshape(2, 1000)
    # Eager reference result: scatter the values along dim 1 of a zeroed
    # (2, 196, 1000) buffer using the same indices the kernels receive.
    buf = torch.zeros(2, 196, 1000, device="cuda")
    buf.scatter_(1, arg4_1, arg7_1.reshape(2, 1, 1000))
    print(buf[0].amax(-1)[:10]) #correct answer
    call(arg4_1, arg7_1)

Output:

tensor([981., 982., 983., 984., 985., 986., 987., 988., 989., 990.],
       device='cuda:0') #correct, output of torch.Tensor.scatter
tensor([981., 982., 983., 984., 985., 986., 987., 988., 989., 990.],
       device='cuda:0') tensor(2001000., device='cuda:0') #correct, w/o xnumel=2000 in the kernel
tensor([984.,   0.,   0.,   0., 988.,   0.,   0.,   0., 992.,   0.],
       device='cuda:0') tensor(2001000., device='cuda:0') #wrong, xnumel=2000 in the kernel

ngimel avatar Oct 05 '22 05:10 ngimel