
Question: Layer Norm tutorial 05

[Open] Arnaud15 opened this issue 1 year ago • 0 comments

I think I am missing something when going over the Triton implementation of LayerNorm in the tutorial.

In the forward kernel, the mean is computed as:

_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    _mean += a
mean = tl.sum(_mean, axis=0) / N
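To make the loop concrete, here is a minimal pure-Python emulation of that snippet for a single row. It assumes only what the snippet shows: `_mean` is a BLOCK_SIZE-wide vector of per-lane partial sums, and masked-out lanes read `other=0.`. The function name `chunked_mean` is hypothetical and not part of Triton or the tutorial.

```python
def chunked_mean(x, block_size):
    """Emulate the tutorial's looped mean for one row (hypothetical helper)."""
    n = len(x)
    acc = [0.0] * block_size  # plays the role of _mean = tl.zeros([BLOCK_SIZE])
    for off in range(0, n, block_size):  # for off in range(0, N, BLOCK_SIZE)
        for lane in range(block_size):
            col = off + lane  # cols = off + tl.arange(0, BLOCK_SIZE)
            # masked load: lanes with col >= n contribute `other=0.`
            acc[lane] += float(x[col]) if col < n else 0.0
    # tl.sum(_mean, axis=0) / N: reduce the per-lane partial sums
    return sum(acc) / n


row = [1.0, 2.0, 3.0, 4.0, 5.0]
print(chunked_mean(row, 2))  # three loop trips over a 2-wide block
print(chunked_mean(row, 8))  # a single trip, since N <= BLOCK_SIZE
```

With `block_size=2` the loop accumulates several chunks into the same lanes before the final reduction; with `block_size=8` it runs once and each lane holds at most one element.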

But further down, the tutorial enforces N <= BLOCK_SIZE:

if N > BLOCK_SIZE:
    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
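A quick way to see the consequence of this guard: count how many times `range(0, N, BLOCK_SIZE)` iterates. The sketch below assumes BLOCK_SIZE is chosen as the next power of two of N, capped at 64 KB worth of elements (which is what the error message suggests); `pick_block_size` and `loop_trips` are hypothetical helpers, not tutorial or Triton APIs.

```python
def next_power_of_2(n):
    # Smallest power of two >= n, for n >= 1.
    return 1 << (n - 1).bit_length()


def pick_block_size(n, element_size):
    # Assumption: cap the block at 64 KB of elements, otherwise round N up.
    max_fused_size = 65536 // element_size
    return min(max_fused_size, next_power_of_2(n))


def loop_trips(n, element_size):
    block_size = pick_block_size(n, element_size)
    if n > block_size:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # Number of iterations of `for off in range(0, N, BLOCK_SIZE)`.
    return len(range(0, n, block_size))


print(loop_trips(1000, 4))   # fp32 row of 1000 features
print(loop_trips(20000, 2))  # fp16 row of 20000 features
```

Under this assumption, every N that passes the guard yields exactly one loop iteration, which is the observation behind the question.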

So I am a bit confused by the forward kernel. Couldn't we do something as simple as:

cols = tl.arange(0, BLOCK_SIZE)
mean = tl.sum(tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)) / N
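The claim that the two forms agree when N <= BLOCK_SIZE can be checked with a pure-Python emulation of both. This is a sketch under the same assumptions as the snippets above (masked lanes read 0.0, one row per call); `looped_mean` and `single_block_mean` are hypothetical names, and real Triton reductions may sum in a different order, so bitwise equality on a GPU is not guaranteed.

```python
def looped_mean(x, block_size):
    # Tutorial form: accumulate BLOCK_SIZE-wide chunks, then reduce once.
    n = len(x)
    acc = [0.0] * block_size
    for off in range(0, n, block_size):
        for lane in range(block_size):
            col = off + lane
            acc[lane] += float(x[col]) if col < n else 0.0
    return sum(acc) / n


def single_block_mean(x, block_size):
    # Proposed form: one masked load of the whole row, then reduce.
    n = len(x)
    if n > block_size:
        raise ValueError("single-block form needs N <= BLOCK_SIZE")
    vals = [float(x[c]) if c < n else 0.0 for c in range(block_size)]
    return sum(vals) / n


row = [0.5, 1.5, 2.0, 4.0, 7.0]
print(looped_mean(row, 8), single_block_mean(row, 8))  # same result
print(looped_mean(row, 2))  # looped form still works with N > BLOCK_SIZE
```

The only behavioral difference in this emulation is that the looped form also accepts N > BLOCK_SIZE, while the single-block form does not.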

I think I might be missing a subtlety here. Is there a situation where the more complex, looped implementation in the tutorial is useful? Even if I used a BLOCK_SIZE < N, my impression is that I would launch more kernels rather than go through the for loop over range(0, N, BLOCK_SIZE).

Very curious to have your thoughts, thank you!

Arnaud15, Feb 14 '24 00:02