triton
Questions about the tl.device_print function: did something go wrong?
For the matmul example code, I want to print the offsets when BLOCK_SIZE = 16 and M = 30, meaning the last block doesn't have enough data.
I use tl.device_print, but it prints pid ... idx ... bn: x, where x is an integer. In my view, x should be a tensor like [0,1,2,3] or [4,5,6,0] (0 meaning out of range) when BLOCK_SIZE = 4 and M = 7. Am I correct?
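For reference, the small case described here (BLOCK_SIZE = 4, M = 7, second program) can be checked on the host with plain NumPy, since the offset arithmetic is the same as in the kernel below. This is a sketch of the computation, not device code:

```python
import numpy as np

# Host-side sketch of the kernel's offset computation for the small case
# in the question: BLOCK_SIZE = 4, M = 7, second program (pid_m = 1).
BLOCK_SIZE = 4
M = 7
pid_m = 1

offset_am = (pid_m * BLOCK_SIZE + np.arange(0, BLOCK_SIZE)) % M
print(offset_am)  # -> [4 5 6 0]
```

Note that the trailing 0 is a wrapped index (row 0 is read again), not an out-of-range marker.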
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# hyperparameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
ACTIVATION: tl.constexpr,
):
...
# Compute ptrs
# [1,9] * [9,1] -> (pid_m, pid_n)
# 8x8 @ 8x8 -> 8x8, 4x4 blocks; consider the 2nd block
offset_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M # [4,5,6,7]
offset_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N # [0,1,2,3]
tl.device_print("am", offset_am)
tl.device_print("bn", offset_bn)
offset_k = tl.arange(0, BLOCK_SIZE_K) # [0,1,2,3]
a_ptrs = a_ptr + offset_am[:, None] * stride_am + offset_k[None, :] * stride_ak # broadcast to 2D
b_ptrs = b_ptr + offset_k[:, None] * stride_bk + offset_bn[None, :] * stride_bn # broadcast to 2D
Print log:
pid (4, 0, 0) idx ( 0) am: 64
pid (4, 0, 0) idx ( 1) am: 65
pid (4, 0, 0) idx ( 2) am: 66
pid (4, 0, 0) idx ( 3) am: 67
pid (4, 0, 0) idx ( 4) am: 68
pid (4, 0, 0) idx ( 5) am: 69
pid (4, 0, 0) idx ( 6) am: 70
pid (4, 0, 0) idx ( 7) am: 71
pid (4, 0, 0) idx ( 8) am: 72
I don't really know what pid and idx mean. Is there any explanation of these parameters?
@ThomasRaoux @zhanglx13
In general, pid refers to the program id (Triton term), block id (CUDA term), or workgroup id (HIP term). idx refers to the index of the value within the tensor.
In your case, block_size = 16, so each program deals with a vector (a tensor, in general) of 16 elements. Then program 4 (the 5th program) deals with a vector with elements 64, 65, ..., 79.
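The pid/idx mapping in the log can be imitated on the host. This is a rough sketch assuming pid = 4 and BLOCK_SIZE = 16 as in the log above; each printed line corresponds to one element of the per-program offset vector, which is why tl.device_print emits a scalar per line rather than the whole tensor:

```python
# Program pid covers offsets pid*BLOCK_SIZE .. pid*BLOCK_SIZE + BLOCK_SIZE - 1.
BLOCK_SIZE = 16
pid = 4
offsets = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]

# Re-create the first few log lines; idx indexes into the vector.
for idx, val in enumerate(offsets[:4]):
    print(f"pid ({pid}, 0, 0) idx ({idx:2d}) am: {val}")
```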
For more information about printOp, check the comments in its implementation.
Thanks for your reply. I am confused about this line: offset_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
When M % BLOCK_SIZE_M != 0 (meaning the last block doesn't have enough data), should the offsets in the range [M, (M // BLOCK_SIZE_M + 1) * BLOCK_SIZE_M) be ignored?
The offset_am is computed like this:
- pid_m * BLOCK_SIZE_M is a scalar. Let's assume it's 1 * 16 = 16.
- tl.arange(0, BLOCK_SIZE_M) makes a vector (tensor) of [0, 1, ..., 15].
- Then you add them and get: [16, 17, ..., 31].
- At last, you apply the % M part (say M = 30) and get: [16, 17, 18, ..., 29, 0, 1].
As you can see, you still have a vector of 16 elements, but the last two values are "wrapped" so that we don't access out-of-bounds elements.
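The steps above can be reproduced on the host with NumPy. This is a sketch of the same arithmetic, using pid_m = 1, BLOCK_SIZE_M = 16, and M = 30 as in the walkthrough:

```python
import numpy as np

BLOCK_SIZE_M = 16
M = 30
pid_m = 1

base = pid_m * BLOCK_SIZE_M          # scalar: 16
idx = np.arange(0, BLOCK_SIZE_M)     # [0, 1, ..., 15]
offset_am = (base + idx) % M         # wrap the out-of-range tail back to 0, 1

print(offset_am)  # [16 17 ... 29  0  1] -- 16 elements, last two wrapped
```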