triton
triton copied to clipboard
`tl.cdiv` gives incorrect results for negative numbers
🐛 Describe the bug
I'm observing that Triton gives an incorrect answer when computing the ceiling division of a negative integer. `triton.cdiv` gives the correct answer, but `tl.cdiv` (on the exact same inputs) gives the wrong answer. Writing my own equivalent of `tl.cdiv(a, b)` as `(a + b - 1) // b` also yields the same incorrect answer.
Versions
- triton==2.0.0.dev20221025
- torch==1.12.1+cu116
- NVIDIA A100-SXM4-80GB
Reproducer
I've adapted the code below from the low-memory dropout example. It should take the ceiling of -7.2 and return -7, but it instead produces -6.
import torch
import triton
import triton.language as tl
@triton.jit
def _cdiv_bug(
    output_ptr,
    n_rows,
    n_cols,
    w,
    lam,
    T2,
    stride_xn,
    stride_xm,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``tl.cdiv(1 + w - T2 - (w + lam - 1), lam)`` to every in-bounds element.

    Launched on a 2D grid: the program index on axis 0 selects a tile of
    rows and the index on axis 1 selects a tile of columns, each tile being
    ``BLOCK_SIZE`` wide. No input tensor is read; the same scalar result is
    stored at every masked-in position.

    Args:
        output_ptr: Base pointer that the results are stored through.
        n_rows: Row bound used to mask the row offsets.
        n_cols: Column bound used to mask the column offsets.
        w: Scalar operand of the expression under test.
        lam: Scalar operand; also the divisor passed to ``tl.cdiv``.
        T2: Scalar operand of the expression under test.
        stride_xn: Multiplier applied to row offsets when addressing.
        stride_xm: Multiplier applied to column offsets when addressing.
        BLOCK_SIZE: Compile-time tile edge length.
    """
    # Locate the tile owned by this program instance.
    pid_row = tl.program_id(axis=0)
    pid_col = tl.program_id(axis=1)
    rows = pid_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    cols = pid_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Broadcast the 1D offsets against each other to form the 2D tile.
    rows_2d = rows[:, None]
    cols_2d = cols[None, :]
    tile_offsets = rows_2d * stride_xn + cols_2d * stride_xm

    # Only positions inside the (n_rows, n_cols) bounds are written.
    in_bounds = (rows_2d < n_rows) & (cols_2d < n_cols)

    # Ceiling division of a value derived from runtime scalar arguments.
    # Passing literals instead, i.e. tl.cdiv(-72, 10), was reported to
    # correctly compute -7.
    # NOTE(review): the reported -6 is what (-72 + 10 - 1) // 10 gives if
    # `//` truncates toward zero instead of flooring -- confirm against
    # Triton's integer-division semantics.
    a = w + lam - 1
    numerator = 1 + w - T2 - a
    v = tl.cdiv(numerator, lam)

    tl.store(output_ptr + tile_offsets, v, mask=in_bounds)
def cdiv_bug(x: torch.Tensor) -> torch.Tensor:
    """Launch ``_cdiv_bug`` over a tensor shaped like ``x`` and return its output.

    The host-side reference value is computed with ``triton.cdiv`` on the
    same numerator and denominator the kernel uses, and is printed before
    the launch so it can be compared with what the kernel stores.

    Args:
        x: Contiguous 2D tensor. Only its shape, strides, dtype and device
            are used; its values are never read.

    Returns:
        A new tensor created with ``torch.empty_like(x)`` and filled by the
        kernel.
    """
    BLOCK_SIZE = 16
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_rows, n_cols = x.shape

    # One program per BLOCK_SIZE tile along each axis. Axis 1 is sized from
    # n_cols (not n_rows) because the kernel masks axis 1 against n_cols;
    # otherwise a non-square input could leave column tiles unwritten.
    grid_x = triton.cdiv(n_rows, BLOCK_SIZE)
    grid_y = triton.cdiv(n_cols, BLOCK_SIZE)
    grid = (grid_x, grid_y)

    # Scalars chosen so the numerator is negative (-72) with divisor 10.
    r = 2
    w = 8
    T1 = 32
    T2 = T1 * r
    lam = 10

    # Mirror the kernel's arithmetic on the host for the reference value.
    a = w + lam - 1
    numerator = 1 + w - T2 - a
    denominator = lam
    expected_answer = triton.cdiv(numerator, denominator)
    print(f"{expected_answer=}")

    _cdiv_bug[grid](
        output,
        n_rows,
        n_cols,
        w,
        lam,
        T2,
        x.stride(0),
        x.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
def main() -> None:
    """Run the reproducer on a random 16x16 CUDA tensor and print one result."""
    sample = torch.rand(size=(16, 16)).cuda()
    # Every in-bounds element receives the same value, so element [0, 0]
    # is inspected.
    triton_answer = cdiv_bug(sample)[0, 0].item()
    print(f"{triton_answer=}")
# Run the reproducer only when executed as a script, so importing this
# module does not launch a GPU kernel as a side effect.
if __name__ == "__main__":
    main()
# Program output
# expected_answer=-7
# triton_answer=-6.0