is there a plan to add argsort?
One useful op for topk is tl.sort(); however, it doesn't return the indices the way torch.sort does (https://pytorch.org/docs/stable/generated/torch.sort.html).
May I know if there's a plan to support an argsort kind of operation?
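For context, torch.sort returns both the sorted values and the indices of the original elements:

import torch

v, i = torch.sort(torch.tensor([0.9, 0.5, 0.2, 0.6]))
print(v)  # tensor([0.2000, 0.5000, 0.6000, 0.9000])
print(i)  # tensor([2, 1, 3, 0])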
Nobody is working on it at the moment; however, tl.sort is implemented in the standard library, so you should be able to implement your own version based on it.
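For reference, a minimal sketch of a kernel that uses the existing tl.sort (values only, no indices). The kernel name and launch parameters below are made up for illustration, and the exact tl.sort signature may differ between Triton versions:

import torch
import triton
import triton.language as tl

@triton.jit
def sort_values_kernel(x_ptr, o_ptr, stride_m,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    offs = ((pid_m * BLOCK_M + tl.arange(0, BLOCK_M))[:, None] * stride_m
            + tl.arange(0, BLOCK_N)[None, :])
    x = tl.load(x_ptr + offs)
    o = tl.sort(x, dim=1)  # sorted values along the last dim; the indices are lost
    tl.store(o_ptr + offs, o)

x = torch.rand(8, 16, dtype=torch.float16, device='cuda')
o = torch.empty_like(x)
sort_values_kernel[(2,)](x, o, x.stride(0), BLOCK_M=4, BLOCK_N=16)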
So I tried looking into it, and it seems I can also return the indices, as argsort does, by modifying the tl.sort op.
Would you like to add it to the standard library? If so I can submit a PR, or would a separate tl.argsort be better?
If you can share a link to your implementation once it is done, that would be great, but at this time I'm not sure we want to include it in the standard library, as we may make some changes to tl.sort in the near future to make it more efficient.
The following code snippet seems to work for me:
import sys, os
import torch
import triton
import triton.language as tl
import triton.language.core as core
from triton.language.standard import _log2, sum, zeros_like
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    y = core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)
    # split the indices into the same left/right halves so they travel with the values
    y_idx = core.reshape(ids, shape)
    left_idx = core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = core.reshape(left_idx, x.shape)
    right_idx = core.reshape(right_idx, x.shape)
    # actual compare-and-swap: when a pair is out of order w.r.t. `flip`,
    # XOR-ing with (left ^ right) swaps the two lanes in place
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))
    # apply the same swap to the indices
    new_ids = ids ^ core.where(cond, left_idx ^ right_idx, zeros_like(ids))
    return ret.to(x.dtype, bitcast=True), new_ids
@triton.jit
def _bitonic_merge(x, ids, stage: core.constexpr, order: core.constexpr,
                   n_dims: core.constexpr):
    '''
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    '''
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids,
            dim: core.constexpr = None,
            descending: core.constexpr = core.CONSTEXPR_0):
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1,
                       "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
@triton.jit
def sort_kernel(
        # Pointers to matrices
        x_ptr,
        o_ptr,
        id_ptr,
        stride_m,
        stride_n,  # unused: the column stride is assumed to be 1 below
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    m_offset = pid_m * stride_m * BLOCK_M
    k_off = tl.arange(0, BLOCK_N)
    x_ptrs = x_ptr + m_offset + (tl.arange(0, BLOCK_M)[:, None] * stride_m +
                                 k_off[None, :])
    # shape: [BLOCK_M, BLOCK_N]
    x = tl.load(x_ptrs)
    # each row starts with the identity indices 0..BLOCK_N-1
    ids = tl.broadcast_to(tl.arange(0, BLOCK_N)[None, :], (BLOCK_M, BLOCK_N))
    # o, ids = argsort(x, ids, 1, True)
    o, ids = argsort(x, ids, 1, False)
    o_ptrs = o_ptr + m_offset + (tl.arange(0, BLOCK_M)[:, None] * stride_m +
                                 k_off[None, :])
    id_ptrs = id_ptr + m_offset + (tl.arange(0, BLOCK_M)[:, None] * stride_m +
                                   k_off[None, :])
    tl.store(o_ptrs, o)
    tl.store(id_ptrs, ids)
if __name__ == '__main__':
    x = [
        [0.9, 0.5, 0.2, 0.6],
        [0.3, 0.1, 0.2, 0.2],
        [0.3, 0.9, 0.2, 0.7],
        [0.05, 0.1, 0.2, 0.002],
    ]
    b = x
    x = torch.tensor(
        x,
        dtype=torch.float16,
        device='cuda',
    )
    o = torch.empty_like(x)
    # ids = torch.empty(x.shape, dtype=torch.int, device='cuda')
    ids = torch.empty(x.shape, dtype=torch.int64, device='cuda')
    BLOCK_M = 2
    BLOCK_N = 4
    grid = (
        triton.cdiv(x.shape[0], BLOCK_M),
        triton.cdiv(x.shape[1], BLOCK_N),
    )
    k = sort_kernel[grid](x, o, ids, x.stride(0), x.stride(1), BLOCK_M,
                          BLOCK_N)
    # path = os.path.join(os.path.dirname(__file__), 'ttgir.mlir')
    # with open(path, 'w') as f:
    #     f.write(k.asm['ttgir'])
    #
    # path = os.path.join(os.path.dirname(__file__), 'ttir.mlir')
    # with open(path, 'w') as f:
    #     f.write(k.asm['ttir'])
    print(k.asm.keys())
    print('result: ')
    print(o)
    print('ids: ')
    print(ids)
    # ref_o, ref_ids = torch.sort(x, 1, True)
    ref_o, ref_ids = torch.sort(x, 1, False)
    print('ref: ')
    print(ref_o)
    print(ref_ids)
    print(ref_ids.dtype)
    # print('reconstruct: ')
    # for i in range(len(b)):
    #     arr = b[i]
    #     reconstruct = [arr[ids[i, j]] for j in range(len(arr))]
    #     print(reconstruct)
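A quick sanity check one could append after the prints in the __main__ block (hypothetical, not part of the original snippet); since a row may contain duplicate values, comparing the gathered values is safer than comparing the indices directly:

# the sorted values must match torch.sort exactly
assert torch.equal(o, ref_o)
# indices may break ties differently, but gathering with them must reproduce the sorted values
assert torch.equal(torch.gather(x, 1, ids), ref_o)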
This looks great. Let's see later whether we want to integrate it into the standard library. We also know that we are missing a place for extension libraries; this would be a good fit for that, and in the past we had talked about making another repo for it. Hopefully you are unblocked; let's see in the future how to enable better sharing.
Also note that this tie-breaking policy does not agree with torch.sort(*, stable=True), because argsort is sensitive to duplicated elements.
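For illustration (hypothetical values): torch's stable sort keeps equal elements in their original order, while the bitonic network above is free to swap them, so the returned indices can differ even though the sorted values match.

import torch

x = torch.tensor([[0.3, 0.1, 0.2, 0.2]], dtype=torch.float16, device='cuda')
ref_o, ref_ids = torch.sort(x, dim=1, stable=True)
print(ref_ids)  # indices [[1, 2, 3, 0]]: the two 0.2 entries keep their original order
# the bitonic argsort above may instead return e.g. [[1, 3, 2, 0]]: same values, different tie order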