Accessing slices of a tensor
Hi,
I am trying to modify the LayerNorm sample code so that intermediate computations are stored in shared memory to avoid recomputation. That is, I want to store the result of "(a - mean)" in shared memory. I want to experiment with whether saving this result improves performance in cases where the normalization dimension is small (but still larger than the block size).
Basically, what I did was this:
tmp = tl.zeros([TMP_LEN], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
    a = tl.where(cols < N, a - mean, 0.)
    _var += a * a
    # save calculated result in float32
    tmp[off:(off + BLOCK_SIZE)] = a
But the Triton compiler gave the following error:
(call stack omitted)
  File "/home/yidoe/triton/python/triton/code_gen.py", line 347, in visit_Assign
    _names += [self.visit(target)]
  File "/home/yidoe/triton/python/triton/code_gen.py", line 754, in visit
    return super().visit(node)
  File "/home/yidoe/anaconda3/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/yidoe/triton/python/triton/code_gen.py", line 559, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError
Does this mean that slicing a tensor this way isn't supported? Also, is my understanding correct that "tmp" would be allocated in shared memory? Thank you!
The full kernel is here:
@triton.jit
def _layer_norm_fwd_fused_dedup(
    Out,
    A,
    Weight,
    Bias,
    Mean, Rstd,
    stride, N, eps,
    TMP_LEN: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # position of elements processed by this program
    row = tl.program_id(0)
    Out += row * stride
    A += row * stride
    # compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    # allocate shared memory to save calculated result
    tmp = tl.zeros([TMP_LEN], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
        a = tl.where(cols < N, a - mean, 0.)
        _var += a * a
        # save calculated result in float32
        tmp[off:(off + BLOCK_SIZE)] = a
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # write-back mean/rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # multiply by weight and add bias
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        weight = tl.load(Weight + cols, mask=mask)
        bias = tl.load(Bias + cols, mask=mask)
        a = tl.load(A + cols, mask=mask, other=0.).to(tl.float32)
        a_hat = (a - mean) * rstd
        out = a_hat * weight + bias
        # write-back
        tl.store(Out + cols, out, mask=mask)
Yeah, slices aren't yet supported in Triton.
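One possible workaround until slicing lands (a sketch of my own, not an official API; the extra Tmp argument and its layout are assumptions): pass a global-memory scratch buffer to the kernel and stage the centered values there with tl.store/tl.load instead of slicing a register tensor. Reusing the names from the kernel above, it could look like this:

# Tmp is a hypothetical float32 scratch buffer with the same row stride as A
# (one row per program), allocated by the host and passed as an extra argument.
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    a = tl.load(A + cols, mask=cols < N, other=0.).to(tl.float32)
    a = tl.where(cols < N, a - mean, 0.)
    _var += a * a
    # stage (a - mean) in global memory instead of a register slice
    tl.store(Tmp + row * stride + cols, a, mask=cols < N)
# ... variance reduction, rstd, and the mean/rstd write-back stay unchanged ...
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    weight = tl.load(Weight + cols, mask=mask)
    bias = tl.load(Bias + cols, mask=mask)
    # reload the staged (a - mean) instead of recomputing it
    a_centered = tl.load(Tmp + row * stride + cols, mask=mask, other=0.)
    out = a_centered * rstd * weight + bias
    tl.store(Out + cols, out, mask=mask)

Note that this trades recomputation for extra global-memory traffic, so whether it actually helps depends on N and how well the scratch buffer stays in cache.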
Thank you! Two more questions:
- Is there a plan to support tensor slicing in Triton kernels? It would allow granular access to allocated tensors.
- When I allocate a tensor like this:
  tmp = tl.zeros([TMP_LEN])
  is it placed in shared memory and spilled to HBM only when it doesn't fit in shared memory?
- It won't be straightforward, as we'll have to exchange data across threads. @ptillet I think it might be possible with a convert-layout op that first stores the data to shared memory and then reads it back.
- Registers. You don't explicitly allocate any shared memory data.
There is a plan to support them, probably after triton-mlir is merged. As @Jokeren mentioned, we could probably get some slow support working pretty easily, but getting it right (i.e., using warp-shuffles when appropriate, or even just shuffling data within the same cuda thread when applicable) will take time.
Is there any update on tensor slicing? Looking forward to this support.
Unfortunately, we are busy working on the torchinductor integration and haven't made progress on this topic so far.
We'll update you in the near future.
Unfortunately, not yet.
There is a functional workaround invented by @apgoucher. You can see an example here: https://github.com/openai/triton/blob/f2bc68ec0d92771368693c138566b555720fe389/python/triton/language/standard.py#L310C1-L310C1
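For anyone who wants the flavor of such a workaround (an illustrative sketch of my own, not necessarily what the linked standard.py code does, and assuming TMP_LEN and BLOCK_SIZE are both powers of two): a slice assignment like tmp[off:off+BLOCK_SIZE] = a can be emulated with broadcasting and a reduction over a one-hot selection matrix:

# Emulate tmp[off:off+BLOCK_SIZE] = a without slicing (illustrative sketch).
# Builds a [TMP_LEN, BLOCK_SIZE] one-hot matrix, so it costs
# O(TMP_LEN * BLOCK_SIZE) registers -- only reasonable for small sizes.
dst = tl.arange(0, TMP_LEN)[:, None]             # [TMP_LEN, 1] destination indices
src = off + tl.arange(0, BLOCK_SIZE)[None, :]    # [1, BLOCK_SIZE] source indices
onehot = (dst == src).to(tl.float32)             # 1.0 where destination == source
tmp += tl.sum(onehot * a[None, :], axis=1)       # scatter the chunk into tmp

This relies on tmp starting out as zeros and each destination index being written at most once across loop iterations.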
Also looking for support for slicing tensors in shared memory.
Is there any recent plan to implement tensor slicing?