
`//` operator rounds towards zero

CHDev93 opened this issue · 0 comments

🐛 Describe the bug

Both NumPy and Python perform integer division that rounds toward minus infinity (floor division):

>>> (-1) // 2
-1
>>> import numpy as np
>>> np.arange(-5, -1) // 2
array([-3, -2, -2, -1])

Torch also rounds toward zero (like Triton), but it prints an explicit warning that this behaviour is deprecated. I think the current behaviour is rather unexpected. Maybe there should be a function (like torch.div) that lets the user explicitly choose the rounding behaviour? In my current use case I rely heavily on floor division, even for negative numbers.
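For reference, torch.div already exposes this choice through its rounding_mode argument (available in the torch version listed below):

import torch

a = torch.tensor([-1, -5, 7])
print(torch.div(a, 2, rounding_mode='floor'))  # tensor([-1, -3,  3])
print(torch.div(a, 2, rounding_mode='trunc'))  # tensor([ 0, -2,  3])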

Versions

  • triton==2.0.0.dev20221025
  • torch==1.12.1+cu116
  • NVIDIA A100-SXM4-80GB

Reproducer

import torch
import triton
import triton.language as tl


@triton.jit
def _integer_div(
    output_ptr,
    n_rows,
    n_cols,
    w,
    r,
    stride_xn,
    stride_xm,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    row_id = tl.program_id(axis=0)
    col_id = tl.program_id(axis=1)

    row_offsets = row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    x_offsets = row_offsets[:, None] * stride_xn + col_offsets[None, :] * stride_xm

    # load data from x
    mask = (row_offsets[:, None] < n_rows) & (col_offsets[None, :] < n_cols)

    # write-back: (-w) // r is a scalar, so tl.store broadcasts it across the block
    tl.store(output_ptr + x_offsets, (-w) // r, mask=mask)


def integer_div(x: torch.Tensor) -> torch.Tensor:
    BLOCK_SIZE = 16
    output = torch.empty_like(x)  # same dtype as x (float32), hence the "-1.0" in the output below
    assert x.is_contiguous()

    n_rows, n_cols = x.shape
    grid_x = triton.cdiv(n_rows, BLOCK_SIZE)
    grid_y = triton.cdiv(n_cols, BLOCK_SIZE)
    grid = (grid_x, grid_y)

    r = 5
    w = 8

    expected_answer = (-w) // r
    print(f"{expected_answer=}")

    _integer_div[grid](
        output,
        n_rows,
        n_cols,
        w,
        r,
        x.stride(0),
        x.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output


def main() -> None:
    x = torch.rand(size=(16, 16)).cuda()
    out = integer_div(x)
    triton_answer = out[0, 0].item()
    print(f"{triton_answer=}")


main()
# Program output
# expected_answer=-2
# triton_answer=-1.0
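
As a possible workaround until the semantics are settled, floor division can be emulated on top of the truncating `//`. A minimal sketch, assuming elementwise tl.where and comparison operators (the floor_div helper name is mine, not part of triton.language):

import triton
import triton.language as tl


@triton.jit
def floor_div(a, b):
    # Truncating quotient, corrected downward by one whenever the
    # operands have opposite signs and the division is inexact.
    q = a // b
    rem = a - q * b
    return tl.where((rem != 0) & ((rem < 0) != (b < 0)), q - 1, q)

Replacing the store in the reproducer with tl.store(output_ptr + x_offsets, floor_div(-w, r), mask=mask) should then print triton_answer=-2.0 rather than -1.0.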

CHDev93 · Dec 06 '22 16:12