questions about `slice_col_par`

Open · Lenan22 opened this issue 2 years ago · 2 comments

```cpp
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters;     // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx;       // index of threadblock in current slice; numbered bottom to top

if (slice_col_par >= n_tiles) {
  // ...
}
```

I have some questions about the code above. For example, suppose the GPU has 108 SMs and the computed `iters` is 19, with `blockIdx.x` ranging from 0 to 127. Is `slice_col_par` calculated directly from `iters = 19`? For instance, the thread block with `blockIdx.x = 5` (or another block) might not actually iterate 19 times.

Lenan22 · Apr 07 '24 13:04

If the batch size is larger than 64, we essentially process multiple batch-size-64 matmuls in a single kernel invocation (to allow better partitioning). This is done by virtually replicating the matrix. Consider this example: `parallel = 2`, a matrix that partitions into 4 tiles, and 3 SMs:

SM -> tile assignment:

```
00 12
01 12

01 01 // slice_col
01 23 // slice_col_par
```

`slice_col` points to the actual column in the matrix, and `slice_col_par` to the column in the virtually replicated version.

Yes, it can happen that a few SMs (here SM 2) process fewer tiles than others; however, the distribution should usually be quite even, since our partitioning is designed so that one SM can partially process multiple columns (see SM 0 or SM 1 above).

efrantar · Apr 07 '24 13:04

Thanks a lot

Lenan22 · Apr 10 '24 11:04