
[QST] How does slice-K reduce the values?


Hi! I am learning 'tall' matmul and find it hard to find the code describing how slice-K reduces the values... I think each warp will calculate 32*64 values (each thread calculates 64 values) and then... we want to reduce between different warps? So we cannot use a warp-sync function like __shfl_down_sync? Will we use atomicAdd to global memory?

Thank you!!!

Haha, actually I am guessing how cuBLAS works using Nsight Compute. I am computing a (3072 x 3072) * (3072 x 64) => (3072 x 64) matmul, and I see this in Nsight: volta_sgemm_64x32_sliced1x4_nn(1, 96, 1)x(256, 1, 1). Do you have any idea what sliced1x4 means? I only see this 1x4 in cutlass...

Also, what do you think nn means?

Thank you!!!!

Arsmart123 · Jun 17 '22

> we want to reduce between different warps?

Correct.

> So we cannot use a warp-sync function like __shfl_down_sync? Will we use atomicAdd to global memory?

We reduce between warps. We don't use atomics or global memory; we use shared memory. The code is here.
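For intuition, here is a minimal CUDA sketch of that idea (my own illustration, not the actual CUTLASS code; the kernel name `slice_k_reduce` and the 32x64-per-warp fragment size are taken from the question above purely for illustration): every warp holds a partial accumulator for the same output fragment over its slice of K, all warps park their partials in shared memory, and one warp sums them.

```cuda
#include <cuda_runtime.h>

// Illustrative sizes from the question: each warp owns a 32x64 output fragment
// (64 accumulators per thread) and 4 warps are stacked along K.
constexpr int kWarps     = 4;
constexpr int kLanes     = 32;
constexpr int kFragments = 64;

__global__ void slice_k_reduce(float* out) {
    // One slot per (warp, lane, accumulator): 4 * 32 * 64 floats = 32 KB.
    __shared__ float smem[kWarps][kLanes][kFragments];

    int warp = threadIdx.x / 32;
    int lane = threadIdx.x % 32;

    // Pretend each warp already computed partial sums over its K slice.
    float acc[kFragments];
    for (int i = 0; i < kFragments; ++i) {
        acc[i] = static_cast<float>(warp + 1);  // placeholder partial result
    }

    // Step 1: every warp stores its partials to shared memory.
    for (int i = 0; i < kFragments; ++i) {
        smem[warp][lane][i] = acc[i];
    }
    __syncthreads();

    // Step 2: one warp adds the partials of all K slices and writes the result.
    if (warp == 0) {
        for (int i = 0; i < kFragments; ++i) {
            float sum = 0.0f;
            for (int w = 0; w < kWarps; ++w) {
                sum += smem[w][lane][i];
            }
            out[lane * kFragments + i] = sum;  // here every element becomes 1+2+3+4 = 10
        }
    }
}
// Launch with 4 warps: slice_k_reduce<<<1, kWarps * kLanes>>>(d_out);
```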

> Do you have any idea what sliced1x4 means?

Each threadblock tile will be split into 4 parts along the K dimension. Each warp takes care of one part, and they are reduced at the end in the epilogue. Here is a cutlass example. It splits the K dimension into 2 parts because ThreadblockShape::kK / WarpShape::kK = 64 / 32 = 2.
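A tiny sketch of how the slice count falls out of the tile shapes (using the shapes from the referenced example; just an illustration):

```cpp
#include "cutlass/gemm/gemm.h"  // cutlass::gemm::GemmShape

// Tile shapes from the linked example.
using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
using WarpShape        = cutlass::gemm::GemmShape<64, 32, 32>;

// The number of K slices (warps stacked along K) is the ratio of the K extents.
static_assert(ThreadblockShape::kK / WarpShape::kK == 2,
              "this configuration slices K into 2 parts");
```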

> Also, what do you think nn means?

n means non-transposed, which is column major. This is cuBLAS terminology. nn means both multiplicands are column major.
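As a concrete illustration of the nn convention, here is a plain cuBLAS call sketch (the `sgemm_nn` wrapper name is made up, and the sizes are taken from your 3072 x 3072 x 64 case):

```cpp
#include <cublas_v2.h>

// (3072 x 3072) * (3072 x 64) -> (3072 x 64), all operands column major.
void sgemm_nn(cublasHandle_t handle,
              const float* A, const float* B, float* C) {
    const int m = 3072, n = 64, k = 3072;
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle,
                CUBLAS_OP_N, CUBLAS_OP_N,  // "nn": neither operand is transposed
                m, n, k,
                &alpha,
                A, m,   // lda = m for column-major A
                B, k,   // ldb = k for column-major B
                &beta,
                C, m);  // ldc = m for column-major C
}
```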

hwu36 · Jun 18 '22

Yes, thank you very much! And just one smalllll question,

> Each threadblock tile will be split into 4 parts along the K dimension. Each warp takes care of one part.

So you mentioned sliced1x4 means slicing into 4 parts, and I agree. But cuBLAS here actually uses 256 threads, which is 8 warps, so I think two warps will take care of each part, of size 32 * 8 (A) and 8 * 64 (B).
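Just to spell out the arithmetic behind that guess (an illustrative sketch; the thread does not confirm this mapping):

```cpp
// Launch config from Nsight: ...(256, 1, 1) threads per block, sliced1x4.
constexpr int kThreadsPerBlock = 256;
constexpr int kWarpsPerBlock   = kThreadsPerBlock / 32;      // 8 warps
constexpr int kSlicesK         = 4;                          // the "1x4" in sliced1x4
constexpr int kWarpsPerSlice   = kWarpsPerBlock / kSlicesK;  // 2 warps per K slice
static_assert(kWarpsPerSlice == 2, "two warps would cooperate on each K slice");
```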

Thank you!!!

(By the way, in my experience 32 * 4 and 4 * 64 might be better? Two warps per part? I even doubt my conclusion now...)

Arsmart123 · Jun 18 '22

```cpp
using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::RowMajor,
    ElementOutput,   cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<64, 32, 32>,
    cutlass::gemm::GemmShape<16, 8, 8>,
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        64 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator,
        ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2>;
```

I see this in the link you provided and am curious about the sizes and meanings of cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<64, 32, 32>, and cutlass::gemm::GemmShape<16, 8, 8>. How large a tile does each threadblock calculate? How much does it read from A and B?
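For reference, my reading of the three shapes (an assumption based on the CUTLASS 2.x device::Gemm template order: threadblock tile, warp tile, Tensor Core instruction shape), with the derived numbers spelled out:

```cpp
#include "cutlass/gemm/gemm.h"

using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;  // per-threadblock M x N x K tile
using WarpShape        = cutlass::gemm::GemmShape<64, 32, 32>;  // per-warp tile
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;    // one Tensor Core mma (Sm75)

// Warps per threadblock: 1 (M) * 2 (N) * 2 (K, the sliced-K factor) = 4 warps = 128 threads.
constexpr int kWarpCount =
    (ThreadblockShape::kM / WarpShape::kM) *
    (ThreadblockShape::kN / WarpShape::kN) *
    (ThreadblockShape::kK / WarpShape::kK);
static_assert(kWarpCount == 4, "4 warps, i.e. 128 threads per threadblock");

// Per mainloop iteration the threadblock stages a 64 x 64 (M x K) tile of A and
// a 64 x 64 (K x N) tile of B into shared memory, accumulating a 64 x 64 output tile.
```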

Thank you!!!!

Arsmart123 · Jun 19 '22

Thank you!! One last question: how do we choose between splitK and sliceK? Do they address the same problem (large K) with different policies?

Actually I am implementing this volta_sgemm_64x32_sliced1x4_nn(1, 96, 1)x(256, 1, 1): one block calculates 32 * 64 with 8 warps, two warps calculate 32 * 64 together, and each thread calculates 32 values while reading 8 values of B and 4 values of A... It is very slow... I suspect it reads too many values for too little computation; previously 64 values computed per 8 loaded (4 of A and 4 of B) worked well...

Thank you!! This has tortured me for months...

Arsmart123 · Jun 26 '22

splitK is for large K and small M/N. It is used to saturate the GPU.

sliceK is used to reduce the shared memory traffic when the tile size is small.
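To make the saturation point concrete, a back-of-the-envelope example with made-up sizes (not from this thread):

```cpp
// A skinny GEMM: M = N = 512, K = 16384, with 128 x 128 threadblock tiles.
constexpr int M = 512, N = 512, K = 16384;
constexpr int kTileM = 128, kTileN = 128;

// Without split-K only 16 threadblocks exist, far fewer than the ~80 SMs of a
// V100, so most of the GPU sits idle.
constexpr int kBlocksNoSplit = (M / kTileM) * (N / kTileN);     // 16

// Splitting K into 8 slices launches 8x as many threadblocks, each covering
// K / 8 = 2048, at the cost of a final reduction over the partial results.
constexpr int kSplitKSlices  = 8;
constexpr int kBlocksSplitK  = kBlocksNoSplit * kSplitKSlices;  // 128
static_assert(kBlocksSplitK == 128, "enough blocks to keep the SMs busy");
```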

hwu36 · Jun 27 '22

> splitK is for large K and small M/N. It is used to saturate the GPU.

> sliceK is used to reduce the shared memory traffic when the tile size is small.

Could you help me explain "sliceK is used to reduce the shared memory traffic when the tile size is small"? What is "shared memory traffic"? Thanks a lot.

kaiyuanm · Jul 21 '22

> What is "shared memory traffic"?

Loading from and storing to shared memory.

hwu36 · Jul 21 '22
