
Accessing slices of a tensor

Open yd2102 opened this issue 3 years ago • 17 comments

Hi,

I am trying to modify the LayerNorm sample code so that intermediate computations are stored in shared memory to avoid recomputation. That is, I want to store the result of `(a - mean)` in shared memory. I want to experiment with whether saving this result improves performance when the normalization dimension is small (but still larger than the block size).

Basically what I did was like this:

    tmp = tl.zeros([TMP_LEN], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
        a = tl.where(cols < N, a - mean, 0.)
        _var += a * a
        # save calculated result in float32
        tmp[off:(off+BLOCK_SIZE)] = a

But the Triton compiler gave the following error:

(call stack omitted)
  File "/home/yidoe/triton/python/triton/code_gen.py", line 347, in visit_Assign
    _names += [self.visit(target)]
  File "/home/yidoe/triton/python/triton/code_gen.py", line 754, in visit
    return super().visit(node)
  File "/home/yidoe/anaconda3/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/yidoe/triton/python/triton/code_gen.py", line 559, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError

Does this mean that slicing a tensor this way isn't supported? Also, is my understanding correct that `tmp` would be allocated in shared memory? Thank you!

The full kernel is here:

@triton.jit
def _layer_norm_fwd_fused_dedup(
    Out,
    A,
    Weight,
    Bias,
    Mean, Rstd,
    stride, N, eps,
    TMP_LEN: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # position of elements processed by this program
    row = tl.program_id(0)
    Out += row * stride
    A += row * stride
    # compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # allocate shared memory to save calculated result
    tmp = tl.zeros([TMP_LEN], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
        a = tl.where(cols < N, a - mean, 0.)
        _var += a * a
        # save calculated result in float32
        tmp[off:(off+BLOCK_SIZE)] = a
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # write-back mean/rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # multiply by weight and add bias
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        weight = tl.load(Weight + cols, mask=mask)
        bias = tl.load(Bias + cols, mask=mask)
        a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
        # write-back
        tl.store(Out + cols, out, mask=mask)

yd2102 avatar Sep 15 '22 00:09 yd2102

Yeah, slices aren't yet supported in Triton.

ptillet avatar Sep 15 '22 00:09 ptillet

Thank you! Two more questions:

  1. Is there a plan to support tensor slicing in Triton kernels? It would allow more granular access to allocated tensors.
  2. When I allocate a tensor like `tmp = tl.zeros([TMP_LEN])`, is it placed in shared memory and spilled to HBM only when it doesn't fit in shared memory?

yd2102 avatar Sep 15 '22 17:09 yd2102

  1. It won't be straightforward, as we'll have to exchange data across threads. @ptillet I think it might be possible with a convert layout op that stores data to shared memory first and then reads it back from shared memory.
  2. Registers. You don't explicitly allocate any shared memory data.
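
For illustration, since `tl.zeros` values live in registers, one way to avoid recomputing `(a - mean)` without any slicing is to cover the whole row with a single block, choosing `BLOCK_SIZE = triton.next_power_of_2(N)` on the host. A minimal sketch under that assumption (the kernel name is hypothetical, not the tutorial's kernel):

    @triton.jit
    def _layer_norm_fwd_one_block(
        Out, A, Weight, Bias,
        Mean, Rstd,
        stride, N, eps,
        BLOCK_SIZE: tl.constexpr,  # next_power_of_2(N), chosen on the host
    ):
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # the whole row fits in one block, so `a` and `diff` stay in registers
        a = tl.load(A + row * stride + cols, mask=mask, other=0.).to(tl.float32)
        mean = tl.sum(a, axis=0) / N
        diff = tl.where(mask, a - mean, 0.)  # computed once, reused below
        var = tl.sum(diff * diff, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        tl.store(Mean + row, mean)
        tl.store(Rstd + row, rstd)
        weight = tl.load(Weight + cols, mask=mask)
        bias = tl.load(Bias + cols, mask=mask)
        out = diff * rstd * weight + bias  # no second pass over A
        tl.store(Out + row * stride + cols, out, mask=mask)

The trade-off is register pressure: for large N the intermediates will spill, which is exactly the regime the looped kernel above is meant to handle.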

Jokeren avatar Sep 15 '22 18:09 Jokeren

There is a plan to support them, probably after triton-mlir is merged. As @Jokeren mentioned, we could probably get some slow support working pretty easily, but getting it right (i.e., using warp-shuffles when appropriate, or even just shuffling data within the same cuda thread when applicable) will take time.

ptillet avatar Sep 15 '22 23:09 ptillet

> There is a plan to support them, probably after triton-mlir is merged. As @Jokeren mentioned, we could probably get some slow support working pretty easily, but getting it right (i.e., using warp-shuffles when appropriate, or even just shuffling data within the same cuda thread when applicable) will take time.

Is there any update on tensor slicing? Looking forward to this support.

mydmdm avatar Feb 10 '23 11:02 mydmdm

Unfortunately, we are busy working on the TorchInductor integration and haven't made progress on this topic so far.

Will update you in the near future.

Jokeren avatar Feb 10 '23 16:02 Jokeren

> There is a plan to support them, probably after triton-mlir is merged. As @Jokeren mentioned, we could probably get some slow support working pretty easily, but getting it right (i.e., using warp-shuffles when appropriate, or even just shuffling data within the same cuda thread when applicable) will take time.

Is there any update on tensor slicing? Looking forward to this support.

haochengxi avatar Dec 02 '23 15:12 haochengxi

Unfortunately not yet

Jokeren avatar Dec 02 '23 16:12 Jokeren

There is a functional workaround invented by @apgoucher. You can see an example here: https://github.com/openai/triton/blob/f2bc68ec0d92771368693c138566b555720fe389/python/triton/language/standard.py#L310C1-L310C1
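
For reference, here is a hedged sketch of how a slice like `x[off:off + BLOCK]` can be emulated in user code today via a masked reduction (the helper name is hypothetical, and this is not necessarily what the linked helper does):

    @triton.jit
    def _take_slice(x, off, FULL_LEN: tl.constexpr, BLOCK: tl.constexpr):
        # Emulate x[off:off + BLOCK] for a 1D tensor `x` of length FULL_LEN.
        # Build a [BLOCK, FULL_LEN] one-hot selection mask and reduce it away:
        # each output element picks exactly one input element.
        src = tl.arange(0, FULL_LEN)        # indices into x
        dst = off + tl.arange(0, BLOCK)     # indices we want to extract
        sel = src[None, :] == dst[:, None]  # one row per extracted element
        return tl.sum(tl.where(sel, x[None, :], 0.0), axis=1)

This materializes a `[BLOCK, FULL_LEN]` intermediate, so it is only reasonable for small tensors; native slicing support would avoid that cost.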

ThomasRaoux avatar Dec 02 '23 16:12 ThomasRaoux

> There is a plan to support them, probably after triton-mlir is merged. As @Jokeren mentioned, we could probably get some slow support working pretty easily, but getting it right (i.e., using warp-shuffles when appropriate, or even just shuffling data within the same cuda thread when applicable) will take time.

Is there any update on tensor slicing? Looking forward to this support.

horrorChen avatar Apr 21 '24 08:04 horrorChen

Also looking for support for slicing tensors in shared memory.

hgl71964 avatar Apr 21 '24 11:04 hgl71964

Is there any recent plan to implement tensor slicing?

leeeizhang avatar May 29 '24 05:05 leeeizhang