triton
The `pid_m` calculation in the matrix-multiplication tutorial
The code at line 216 defines `pid_m` as follows:
pid_m = first_pid_m + (pid % group_size_m)
This line seems to be incorrect. Why don't we need an additional modulo operation, as follows?
pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
num_pid_in_group = GROUP_SIZE_M * num_pid_n
Because `num_pid_in_group` is a multiple of `GROUP_SIZE_M`, we don't need to do the mod twice?
Yes, but `group_size_m` is not always equal to `GROUP_SIZE_M` in the last group?
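To make the point about the last group concrete, here is a minimal plain-Python sketch (no Triton needed) that mirrors the tutorial's `group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)` expression. The helper name `group_sizes` and the example sizes are made up for illustration.

```python
def group_sizes(num_pid_m, GROUP_SIZE_M):
    """List group_size_m for each group of row-blocks, in order."""
    sizes = []
    for first_pid_m in range(0, num_pid_m, GROUP_SIZE_M):
        # Same expression as in the tutorial kernel: the last group is
        # truncated when num_pid_m is not a multiple of GROUP_SIZE_M.
        sizes.append(min(num_pid_m - first_pid_m, GROUP_SIZE_M))
    return sizes


# Hypothetical sizes: 10 row-blocks, groups of 4.
sizes = group_sizes(10, 4)
assert sizes == [4, 4, 2]  # only the last group is smaller than GROUP_SIZE_M
```

Only in that truncated last group can `pid % group_size_m` differ from `(pid % num_pid_in_group) % group_size_m`.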
Hm, you do have a good point...
I'm not 100% sure how this is supposed to work. Does the code do the wrong thing when we hit that branch?
num_pid_in_group >= group_size_m
Maybe you are confused about the order of `pid_m` in the last group? IIUC, it doesn't matter as long as they are unique.
Hi @Pairshoe, from a code-readability standpoint, it should be `(pid % num_pid_in_group) % group_size_m`. From a computation standpoint, it doesn't matter. Here's a possible explanation (correct me if something is wrong).
Let A = `first_pid_m * num_pid_n`, which is constant for the last group, and let B be a variable in the range [0, `(num_pid_m - first_pid_m) * num_pid_n`), so that `pid` = A + B. Then, for the last group, `pid % group_size_m` is equal to `(A + B) % (num_pid_m - first_pid_m)`, which must lie in the range [0, `num_pid_m - first_pid_m`). The values are all valid; they just come out in a different order.
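The argument above can be checked with a small plain-Python simulation of the index arithmetic. This is only a sketch: `tile_ids` and its `extra_mod` flag are hypothetical names, the grid sizes are made up, and no Triton code is executed. It computes `(pid_m, pid_n)` for every `pid` with and without the extra modulo and compares the results.

```python
def tile_ids(num_pid_m, num_pid_n, GROUP_SIZE_M, extra_mod):
    """Return (pid_m, pid_n) for every program id, in launch order."""
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    out = []
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        # extra_mod=False: pid % group_size_m (the line in question)
        # extra_mod=True:  (pid % num_pid_in_group) % group_size_m
        local = pid % num_pid_in_group if extra_mod else pid
        pid_m = first_pid_m + (local % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        out.append((pid_m, pid_n))
    return out


# Hypothetical grid: 5 x 3 tiles with GROUP_SIZE_M = 3, so the last group has 2 rows.
original = tile_ids(5, 3, 3, extra_mod=False)
with_mod = tile_ids(5, 3, 3, extra_mod=True)
all_tiles = sorted((m, n) for m in range(5) for n in range(3))

# Both variants visit every tile exactly once...
assert sorted(original) == all_tiles
assert sorted(with_mod) == all_tiles
# ...but the visiting order differs inside the last group.
assert original[:9] == with_mod[:9]
assert original[9] == (4, 0) and with_mod[9] == (3, 0)
```

For this example both variants cover the same set of tiles, which matches the claim that only the order within the last group changes.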