Questions about the tl.device_print function: did something go wrong?

Open foreverpiano opened this issue 9 months ago • 4 comments

For the matmul example code, I want to print the offsets when blocksize=16 and M=30, which means the last block doesn't have enough data to fill it. I used tl.device_print, but it prints lines of the form pid ... idx ... bn: x, where x is an integer. In my view, x should be a tensor such as [0,1,2,3] or [4,5,6,0] (0 meaning out of range) when blocksize=4 and M=7. Am I correct?

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, 
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # hyperparameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
): 
 ...
    # Compute ptrs
    # [1,9] * [9,1] -> (pid_m, pid_n)
    # 8*8@8*8->8*8, block4*4, consider 2nd block
    offset_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M # [4,5,6,7]
    offset_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N # [0,1,2,3]
    tl.device_print("am", offset_am)
    tl.device_print("bn", offset_bn)
    offset_k = tl.arange(0, BLOCK_SIZE_K) # [0,1,2,3]
    a_ptrs = a_ptr + offset_am[:, None] * stride_am + offset_k[None, :] * stride_ak # broadcast to 2D
    b_ptrs = b_ptr + offset_k[:, None] * stride_bk + offset_bn[None, :] * stride_bn # broadcast to 2D
    
    

Printed log:

pid (4, 0, 0) idx ( 0) am: 64
pid (4, 0, 0) idx ( 1) am: 65
pid (4, 0, 0) idx ( 2) am: 66
pid (4, 0, 0) idx ( 3) am: 67
pid (4, 0, 0) idx ( 4) am: 68
pid (4, 0, 0) idx ( 5) am: 69
pid (4, 0, 0) idx ( 6) am: 70
pid (4, 0, 0) idx ( 7) am: 71
pid (4, 0, 0) idx ( 8) am: 72

I don't really understand what pid and idx mean. Is there any explanation of these values?

foreverpiano, Apr 30 '24 10:04

@ThomasRaoux @zhanglx13

foreverpiano, Apr 30 '24 10:04

In general, pid refers to the program id (Triton term), also known as the block id (CUDA term) or workgroup id (HIP term). idx refers to the index of the value within the tensor. In your case, block_size = 16, so each program handles a vector (a tensor, in general) of 16 elements. Program 4 (the 5th program) therefore handles a vector with elements 64, 65, ... For more information about printOp, check the comments in its implementation.
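To make the pid/idx pairing concrete, here is a minimal pure-Python emulation (not Triton code) of what the log in this thread suggests: one line per tensor element, tagged with the program id and the element's index. The helper name `emulate_device_print`, the value of `M`, and the exact column formatting are assumptions chosen to mirror the log above, not the actual implementation.

```python
def emulate_device_print(prefix, pid, values):
    """Hypothetical helper: mimic the per-element lines seen in the log above."""
    lines = []
    for idx, value in enumerate(values):
        # One line per element: which program printed it, and where it sits in the tensor.
        lines.append(f"pid ({pid[0]}, {pid[1]}, {pid[2]}) idx ({idx:2d}) {prefix}: {value}")
    return lines


BLOCK_SIZE_M = 16
M = 1024   # assumed: large enough that % M leaves these offsets unchanged
pid_m = 4  # assumed: the row-block index of the program shown in the log

# Plain-list stand-in for (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offset_am = [(pid_m * BLOCK_SIZE_M + i) % M for i in range(BLOCK_SIZE_M)]

log_lines = emulate_device_print("am", (4, 0, 0), offset_am)
for line in log_lines:
    print(line)
```

Read this way, the integer after `am:` is a single element of the tensor, and the 16 lines from one program together make up the whole vector.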

zhanglx13, Apr 30 '24 14:04

Thanks for your reply. I am confused about this line: offset_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M. When M % BLOCK_SIZE_M != 0 (meaning the last block doesn't have enough data), should we ignore the entries of offset_am that correspond to indices in the range [M, ((M // BLOCK_SIZE_M) + 1) * BLOCK_SIZE_M)?

foreverpiano, Apr 30 '24 14:04

offset_am is computed like this:

  1. pid_m * BLOCK_SIZE_M is a scalar. Let's assume it's 1 * 16 = 16
  2. tl.arange(0, BLOCK_SIZE_M) makes a vector (tensor) of [0, 1, ..., 15]
  3. Then you add them and get: [16, 17, ..., 31]
  4. Finally, you apply the % M part (say M=30) and get: [16, 17, 18, ..., 29, 0, 1]

As you can see, you still have a vector of 16 elements, but the last two values are "wrapped" back to the start so that we don't access out-of-bounds elements.
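The four steps can be replayed in plain Python. This is a sketch that uses lists rather than Triton tensors, and `wrapped_offsets` is a hypothetical helper name. It also flags which lanes had a raw offset at or beyond M, since those are exactly the positions that get wrapped.

```python
def wrapped_offsets(pid_m, block_size_m, m):
    """Hypothetical helper mirroring (pid_m * BLOCK_SIZE_M + arange) % M with lists."""
    base = pid_m * block_size_m                 # step 1: a scalar
    lanes = list(range(block_size_m))           # step 2: [0, 1, ..., block_size_m - 1]
    raw = [base + lane for lane in lanes]       # step 3: unwrapped offsets
    wrapped = [offset % m for offset in raw]    # step 4: wrap with % M
    past_end = [offset >= m for offset in raw]  # True where the raw offset is >= M
    return raw, wrapped, past_end


raw, wrapped, past_end = wrapped_offsets(pid_m=1, block_size_m=16, m=30)
print(wrapped)   # wrapped offsets for the second row block
print(past_end)  # which lanes were wrapped
```

Calling `wrapped_offsets(pid_m=1, block_size_m=4, m=7)` gives wrapped offsets of [4, 5, 6, 0], which matches the guess in the original question. If the kernel masks its final store with a row index < M check (an assumption; the store is not shown in this thread), the results computed for the wrapped lanes are simply never written back, so the wrapping only serves to keep the loads in bounds.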

zhanglx13, Apr 30 '24 14:04