triton
triton copied to clipboard
Wrong results when one of the kernel args is reassigned to a constant inside the kernel
This might be related to #714. Repro below (comments inside; unfortunately it requires torchdynamo). tl;dr: if the kernel contains `xnumel = <const>`, where `xnumel` is also a kernel arg and `<const>` is equal to the `xnumel` that is passed to the kernel (so the assignment should be a no-op, or, even if it's used for optimization, it shouldn't change results), the kernel produces wrong results. Note that this is using the new runtime; with the old runtime, both versions of the kernel produce wrong results.
I'm happy to provide the generated PTX, or any additional info, if needed, given that the repro requires dynamo. That said, with minor changes the repro can disable pre-compilation and drop the dynamo dependency while still producing the wrong results.
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import AsyncCompile
# Shorthand for the `torch.ops.aten` operator namespace.
aten = torch.ops.aten
# Compile helper; both Triton kernel sources below are submitted through it.
async_compile = AsyncCompile()
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# Baseline kernel for the repro: `xnumel` is used exactly as passed in by the
# launcher. For each flat index i < xnumel the kernel loads an index value
# tmp0 = in_ptr0[i] and a data value tmp1 = in_ptr1[i], then stores tmp1 at
#     out_ptr0 + (i % 1000) + 1000 * tmp0 + 196000 * (i // 1000)
# The kernel source is kept verbatim apart from restoring the indentation of
# the function body, which the source text must have in order to parse.
kernel1 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor
@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1000
    x1 = (xindex // 1000)
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr1 + x2, xmask)
    tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)
''')
# Variant that triggers the reported miscompile. It is identical to kernel1
# except for the first statement of the body, which rebinds the `xnumel`
# argument to the literal 2000 -- the same value the launcher passes
# (1000 * s1 with s1 = 2 in `call`). The rebinding should therefore be a
# no-op, yet the report shows different (wrong) output for this kernel.
# Do NOT remove the `xnumel = 2000` line: it is the point of the repro.
# The kernel source is kept verbatim apart from restoring the indentation of
# the function body, which the source text must have in order to parse.
kernel2 = async_compile.triton('''
import triton
import triton.language as tl
from torchinductor.ir import ReductionHint
from torchinductor.triton_ops.autotune import pointwise
from torchinductor.utils import instance_descriptor
@pointwise(size_hints=[2048], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1000
    x1 = (xindex // 1000)
    tmp0 = tl.load(in_ptr0 + x2, xmask)
    tmp1 = tl.load(in_ptr1 + x2, xmask)
    tl.store(out_ptr0 + x0 + (1000*tmp0) + (196000*x1) + tl.zeros([XBLOCK], tl.int32), tmp1, xmask)
''')
# NOTE(review): presumably blocks until both kernels have finished compiling
# and rebinds `kernel1` / `kernel2` in this module's globals to the compiled
# objects -- confirm against torchinductor.codecache.AsyncCompile.
async_compile.wait(globals())
# The compile helper is not referenced again in the code shown here.
del async_compile
def call(arg4_1, arg7_1):
    """Run kernel1 and kernel2 on the same inputs and print both results.

    Each kernel writes into its own freshly zero-filled (2, 196, 1000) float32
    CUDA buffer with strides (196000, 1000, 1). After each launch the function
    prints the first ten per-row maxima of ``buf0[0]`` and the total sum of
    the buffer so the two outputs can be compared by eye.

    Args:
        arg4_1: Index tensor, passed to the kernels as ``in_ptr0`` (declared
            ``*i64`` in the kernel signature).
        arg7_1: Value tensor, passed to the kernels as ``in_ptr1`` (declared
            ``*fp32`` in the kernel signature).

    Returns:
        A one-element tuple holding the buffer written by kernel2.
    """
    s1 = 2

    # --- kernel1: xnumel is taken from the launch argument only ---
    buf0 = empty_strided((2, 196, 1000), (196000, 1000, 1), device='cuda', dtype=torch.float32).fill_(0)
    stream0 = get_cuda_stream(0)
    kernel1_xnumel = 1000*s1
    kernel1.run(arg4_1, arg7_1, buf0, kernel1_xnumel, grid=grid(kernel1_xnumel), stream=stream0)
    # Alternative direct launch kept from the original report (explicit
    # XBLOCK of 1024 instead of going through `.run`):
    # kernel1[grid(kernel1_xnumel)](arg4_1, arg7_1, buf0, kernel1_xnumel, 1024)
    print(buf0[0].amax(-1)[:10], buf0.sum()) #no xnumel=2000 in the kernel, correct answer

    # --- kernel2: same launch, but the kernel body rebinds xnumel = 2000 ---
    buf0 = empty_strided((2, 196, 1000), (196000, 1000, 1), device='cuda', dtype=torch.float32).fill_(0)
    kernel2_xnumel = 1000*s1
    # Pass kernel2's own element count. The original passed `kernel1_xnumel`
    # here (same value, but inconsistent with the grid computed on this line).
    kernel2.run(arg4_1, arg7_1, buf0, kernel2_xnumel, grid=grid(kernel2_xnumel), stream=stream0)
    # Alternative direct launch kept from the original report:
    # kernel2[grid(kernel2_xnumel)](arg4_1, arg7_1, buf0, kernel2_xnumel, 1024)
    print(buf0[0].amax(-1)[:10], buf0.sum()) #xnumel=2000 in the kernel, wrong answer
    return (buf0, )
if __name__ == "__main__":
    # Driver for the repro: build deterministic inputs, print the reference
    # result computed with Tensor.scatter_, then run both Triton kernels via
    # `call`, which prints their results for comparison.
    #
    # The original driver also built an unused random tensor `arg` with
    # torchdynamo's `rand_strided` and imported an unused `print_performance`;
    # both were dropped because nothing read them.
    torch.manual_seed(12345)

    # Indices 0..1999 folded into [0, 196), shaped (2, 1, 1000). The original
    # report notes torch.randint(196, ...) as an alternative index source.
    arg4_1 = torch.arange(2000, device="cuda", dtype=torch.int64).reshape(2, 1, 1000) % 196
    # Values 1..2000 as float32, shaped (2, 1000).
    arg7_1 = torch.arange(1, 2001, device="cuda", dtype=torch.float32).reshape(2, 1000)

    # Reference: scatter the values along dim 1 of a zeroed (2, 196, 1000) buffer.
    buf = torch.zeros(2, 196, 1000, device="cuda")
    buf.scatter_(1, arg4_1, arg7_1.reshape(2, 1, 1000))
    print(buf[0].amax(-1)[:10]) #correct answer

    call(arg4_1, arg7_1)
Output:
tensor([981., 982., 983., 984., 985., 986., 987., 988., 989., 990.],
device='cuda:0') #correct, output of torch.Tensor.scatter
tensor([981., 982., 983., 984., 985., 986., 987., 988., 989., 990.],
device='cuda:0') tensor(2001000., device='cuda:0') #correct, w/o xnumel=2000 in the kernel
tensor([984., 0., 0., 0., 988., 0., 0., 0., 992., 0.],
device='cuda:0') tensor(2001000., device='cuda:0') #wrong, xnumel=2000 in the kernel