triton icon indicating copy to clipboard operation
triton copied to clipboard

`tl.cdiv` gives incorrect results for negative numbers

Open CHDev93 opened this issue 2 years ago • 0 comments

🐛 Describe the bug

I'm observing that triton gives the incorrect answer when computing the ceiling division of a negative integer. triton.cdiv gives the correct answer but tl.cdiv (on the exact same inputs) gives the wrong answer. Trying to write my own equivalent of tl.cdiv(a, b) as (a + b - 1) // b also yields the same incorrect answer

Versions

  • triton==2.0.0.dev20221025
  • torch==1.12.1+cu116
  • NVIDIA A100-SXM4-80GB

Reproducer

I've adapted the code below from the low memory dropout example. It should be taking the ceiling of -7.2 and returning -7 but instead produces -6

import torch
import triton
import triton.language as tl


@triton.jit
def _cdiv_bug(
    output_ptr,
    n_rows,
    n_cols,
    w,
    lam,
    T2,
    stride_xn,
    stride_xm,
    BLOCK_SIZE: tl.constexpr,
):
    """Reproducer kernel: fills `output_ptr` with tl.cdiv(1 + w - T2 - a, lam).

    With the caller's values (w=8, lam=10, T2=64) the numerator is -72, so the
    expected ceiling division result is ceil(-72 / 10) = -7; the kernel instead
    stores -6, demonstrating that tl.cdiv mishandles negative numerators.
    The inputs are scalar runtime arguments on purpose — see the commented-out
    line below: the same division with literal constants computes correctly.
    """
    # compute memory offsets of elements handled by this instance
    row_id = tl.program_id(axis=0)
    col_id = tl.program_id(axis=1)

    row_offsets = row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # 2-D offsets into the (n_rows, n_cols) output, in elements.
    x_offsets = row_offsets[:, None] * stride_xn + col_offsets[None, :] * stride_xm

    # load data from x
    # Mask off out-of-bounds rows/columns on the edge tiles.
    mask = (row_offsets[:, None] < n_rows) & (col_offsets[None, :] < n_cols)

    # write-back
    # a = 17, so the numerator 1 + w - T2 - a = 1 + 8 - 64 - 17 = -72.
    a = w + lam - 1
    v = tl.cdiv(1 + w - T2 - a, lam)
    # v = tl.cdiv(-72, 10)  # interestingly this will correctly compute -7 as the answer
    # Broadcast the scalar result v to every in-bounds element of the output.
    tl.store(output_ptr + x_offsets, v, mask=mask)


def cdiv_bug(x: torch.Tensor) -> torch.Tensor:
    """Launch the `_cdiv_bug` kernel over `x` and return the filled output.

    Also computes the same ceiling division on the host with `triton.cdiv`
    (which is correct for negative numerators) and prints it, so the host
    answer can be compared with what the kernel stored.

    Args:
        x: a contiguous 2-D CUDA tensor; only its shape/strides are used.

    Returns:
        A tensor shaped like `x`, every element set to the kernel's
        tl.cdiv result (expected -7, observed -6).
    """
    BLOCK_SIZE = 16
    output = torch.empty_like(x)
    assert x.is_contiguous()

    n_rows, n_cols = x.shape
    grid_x = triton.cdiv(n_rows, BLOCK_SIZE)
    # Fix: the original set grid_y = grid_x, which is only correct for
    # square inputs; derive the column-axis grid size from n_cols.
    grid_y = triton.cdiv(n_cols, BLOCK_SIZE)
    grid = (grid_x, grid_y)

    # Scalar arguments chosen so the kernel's numerator is -72 and lam is 10.
    r = 2
    w = 8
    T1 = 32
    T2 = T1 * r
    lam = 10

    # Host-side reference: triton.cdiv handles the negative numerator correctly.
    a = w + lam - 1
    numerator = 1 + w - T2 - a
    denominator = lam
    expected_answer = triton.cdiv(numerator, denominator)
    print(f"{expected_answer=}")

    _cdiv_bug[grid](
        output,
        n_rows,
        n_cols,
        w,
        lam,
        T2,
        x.stride(0),
        x.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output


def main() -> None:
    """Run the reproducer on a 16x16 CUDA tensor and print the kernel's answer."""
    sample = torch.rand(size=(16, 16)).cuda()
    kernel_output = cdiv_bug(sample)
    triton_answer = kernel_output[0, 0].item()
    print(f"{triton_answer=}")


main()
# Program output
# expected_answer=-7
# triton_answer=-6.0

CHDev93 avatar Dec 06 '22 12:12 CHDev93