triton
Question: Layer Norm tutorial 05
I think I am missing something when going over the Triton implementation of LayerNorm in the tutorial.
In the forward kernel, we compute the mean as:
```python
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    _mean += a
mean = tl.sum(_mean, axis=0) / N
```
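To check my understanding of the loop, here is a plain-Python emulation of it (no GPU or Triton needed). `blocked_mean` is a hypothetical helper, and each inner iteration stands in for one lane of the vectorized `tl.load`. It also counts how many times the loop body runs:

```python
def blocked_mean(x, block_size):
    """Emulate the tutorial's loop: accumulate a BLOCK_SIZE-wide vector, then reduce it."""
    n = len(x)
    acc = [0.0] * block_size  # plays the role of _mean
    iterations = 0
    for off in range(0, n, block_size):
        iterations += 1
        for lane in range(block_size):
            col = off + lane
            # Masked load: out-of-range lanes read `other=0.` and add nothing.
            acc[lane] += x[col] if col < n else 0.0
    # tl.sum(_mean, axis=0) / N
    return sum(acc) / n, iterations


x = [float(i) for i in range(1, 11)]  # N = 10
print(blocked_mean(x, 4))   # three trips through the loop
print(blocked_mean(x, 16))  # one trip, since BLOCK_SIZE >= N
```

With N = 10, a BLOCK_SIZE of 4 takes three trips through the loop, while any BLOCK_SIZE >= N takes exactly one, and both give the same mean.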
But below, we enforce `N <= BLOCK_SIZE`:
```python
if N > BLOCK_SIZE:
    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
```
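As far as I can tell, the wrapper picks BLOCK_SIZE roughly as below. This is my paraphrase of the tutorial, so treat the exact constants as an assumption. `pick_block_size` is a hypothetical name, and `next_power_of_2` here is a stand-in for `triton.next_power_of_2`. The logic caps one block at 64 KB worth of elements and otherwise rounds N up to a power of two:

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def pick_block_size(n_cols, element_size):
    # Assumed wrapper logic: one block may hold at most 64 KB of elements.
    max_fused_size = 65536 // element_size
    block_size = min(max_fused_size, next_power_of_2(n_cols))
    if n_cols > block_size:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    return block_size


print(pick_block_size(1000, 2))  # fp16 row of 1000 features
```

If this returns without raising, then `N <= BLOCK_SIZE`, so `range(0, N, BLOCK_SIZE)` has exactly one element and the loop in the kernel body runs once.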
So I am a bit confused by the forward kernel. Couldn't we do something as simple as:
```python
cols = tl.arange(0, BLOCK_SIZE)
mean = tl.sum(tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)) / N
```
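Here is the single-load version I have in mind, emulated the same way in plain Python (`single_block_mean` is hypothetical). The masked-out lanes read `other=0.`, so they add nothing to the sum, which is why the division is by N rather than by BLOCK_SIZE:

```python
def single_block_mean(x, block_size):
    """Emulate one masked load of BLOCK_SIZE lanes followed by tl.sum(...) / N."""
    n = len(x)
    if n > block_size:
        raise ValueError("the single-load version assumes the row fits in one block")
    cols = range(block_size)                          # tl.arange(0, BLOCK_SIZE)
    lanes = [x[c] if c < n else 0.0 for c in cols]    # mask=cols < N, other=0.
    # Padded lanes are 0, so dividing by N (not BLOCK_SIZE) gives the true mean.
    return sum(lanes) / n


x = [float(i) for i in range(1, 11)]  # N = 10, padded out to 16 lanes
print(single_block_mean(x, 16))
```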
I think I might be missing a subtlety here. Is there a situation where the more complex implementation in the tutorial is useful? Even if I used a `BLOCK_SIZE < N`, I feel like I would launch more kernels instead of going through the for loop over `range(0, N, BLOCK_SIZE)`.
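To make the comparison concrete, here is the arithmetic I am weighing, under an assumption: that the forward kernel is launched with one program per row, i.e. a grid of `(M,)` where M is the number of rows. That is my reading of the tutorial and may be wrong. Under it, BLOCK_SIZE changes how many times each program goes around the loop, not how many programs are launched (`work_layout` is a hypothetical helper):

```python
def work_layout(n_rows, n_cols, block_size):
    """Return (programs launched, loop trips per program), assuming a grid of (M,)."""
    programs = n_rows                       # assumed: one program per row
    iterations = -(-n_cols // block_size)   # ceil(N / BLOCK_SIZE) trips through the loop
    return programs, iterations


print(work_layout(4096, 1000, 1024))  # BLOCK_SIZE >= N
print(work_layout(4096, 1000, 256))   # BLOCK_SIZE < N
```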
Very curious to have your thoughts, thank you!