
[QST] How many threads and blocks does cutlass use? (When C is tall in official post)

Open Arsmart123 opened this issue 3 years ago • 4 comments

Hi! I am learning cutlass, and I read this post: CUTLASS: Fast Linear Algebra in CUDA C++ | NVIDIA Technical Blog. But I cannot find the official "dispatch_policies.h"; I can only find one in Hugging Face's GitHub: pytorch_block_sparse/dispatch_policies.h at master · huggingface/pytorch_block_sparse · GitHub

Actually I am developing a kernel for a "tall" matmul, where the resulting C (M x N) has a large M and a small N. So I am quite interested in the parameters here (as shown in the post):

/// CUTLASS SGEMM example
__global__ void gemm_kernel(
    float *C,
    float const *A,
    float const *B,
    int M,
    int N,
    int K) {

    // Define the GEMM tile sizes - discussed in next section
    typedef block_task_policy <
        128,                        // BlockItemsY: Height in rows of a tile
        32,                         // BlockItemsX - Width in columns of a tile
        8,                          // ThreadItemsY - Height in rows of a thread-tile
        4,                          // ThreadItemsX - Width in columns of a thread-tile
        8,                          // BlockItemsK - Depth of a tile
        true,                       // UseDoubleScratchTiles - whether to double-buffer SMEM
        block_raster_enum::Default  // Block rasterization strategy
    > block_task_policy_t;

    // Define the epilogue functor
    typedef gemm::blas_scaled_epilogue<float, float, float> epilogue_op_t;

    // Define the block_task type.
    typedef block_task <
        block_task_policy_t,
        float,
        float,
        matrix_transform_t::NonTranspose,
        4,
        matrix_transform_t::NonTranspose,
        4,
        epilogue_op_t,
        4,
        true
    > block_task_t;

    // Declare statically-allocated shared storage
    __shared__ block_task_t::scratch_storage_t smem;

    // Construct and run the task
    block_task_t(
        reinterpret_cast(&smem),
        &smem,
        A,
        B,
        C,
        epilogue_op_t(1, 0),
        M,
        N,
        K).run();
}

My question is: how many threads and blocks are allocated here? I really cannot find any info about this nearby…

Thank you!!!

Arsmart123 avatar Jun 15 '22 03:06 Arsmart123

The code you posted belongs to cutlass 0.1. The current cutlass looks very different. Here is what the top level looks like if you use tensor cores: https://github.com/NVIDIA/cutlass/blob/master/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu#L212-L226

The number of threads per threadblock is 32 * (ShapeMMAThreadBlock::kM / ShapeMMAWarp::kM) * (ShapeMMAThreadBlock::kN / ShapeMMAWarp::kN) * (ShapeMMAThreadBlock::kK / ShapeMMAWarp::kK).

The number of threadblocks is (problem_size_M / ShapeMMAThreadBlock::kM) * (problem_size_N / ShapeMMAThreadBlock::kN) * (problem_size_K / ShapeMMAThreadBlock::kK).
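
To make the arithmetic concrete, here is a tiny host-side sketch (not CUTLASS code) that plugs assumed tile shapes into those two formulas; the tile sizes and problem sizes below are only illustrative:

// Minimal sketch: evaluating the two formulas above with assumed tile shapes.
// All names and numbers here are illustrative, not taken from a real kernel.
#include <cstdio>

int main() {
    // Assumed threadblock and warp tile shapes (kM x kN x kK).
    int tb_m = 128, tb_n = 128, tb_k = 16;   // ShapeMMAThreadBlock
    int wp_m = 64,  wp_n = 64,  wp_k = 16;   // ShapeMMAWarp

    // Illustrative "tall" problem: large M, small N.
    int M = 65536, N = 128;

    // Threads per threadblock: 32 threads per warp times warps per threadblock tile.
    int threads = 32 * (tb_m / wp_m) * (tb_n / wp_n) * (tb_k / wp_k);  // 32 * 2 * 2 * 1 = 128

    // Threadblocks: one per threadblock tile of C, rounded up. The K term in the
    // formula above only adds blocks for split-K style kernels; the default kernel
    // iterates over K inside each block.
    int blocks_m = (M + tb_m - 1) / tb_m;  // 512
    int blocks_n = (N + tb_n - 1) / tb_n;  // 1

    printf("threads per block = %d, grid = %d x %d\n", threads, blocks_m, blocks_n);
    return 0;
}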

hwu36 avatar Jun 15 '22 14:06 hwu36

Oh, thank you!!

  1. Actually I am not using tensor cores, so if you have a version without them, even better. I am actually using a 1650 (Turing, sm75).

  2. Also, I notice that here we use 32 * 128/64 * 128/64 * 64/64 = 128 threads per block. I did not find a version specifically for "tall" matrices (meaning the resulting C has a large M but a small N). The official post mentioned different policies for different shapes. Do you know of any policy for choosing tile sizes in the latest cutlass?

Thank you!!!

Arsmart123 avatar Jun 17 '22 03:06 Arsmart123

Different problem sizes need different tile sizes. You can use the cutlass profiler to find them. Here is the doc: https://github.com/NVIDIA/cutlass/blob/master/media/docs/profiler.md

You can use

cmake .. -DCUTLASS_NVCC_ARCHS="75" -DCUTLASS_LIBRARY_KERNELS=sgemm

to generate only the sm75 fp32 non-tensor-core gemm kernels.
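
Once the profiler is built, you could then profile your tall problem shape directly from the build directory with something like the command below; the sizes are just an illustration, and the profiler doc above lists the exact options:

./tools/profiler/cutlass_profiler --kernels=sgemm --m=65536 --n=128 --k=1024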

hwu36 avatar Jun 17 '22 15:06 hwu36

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Jul 17 '22 16:07 github-actions[bot]