
[QST] Is this the complete set of valid parameters for performing fp16 matrix multiplication using tensor cores?

Open zwshan opened this issue 1 year ago • 12 comments

What is your question? There are many parameters listed on this page of the website; may I ask whether the parameters listed there are already all of the valid ones?

zwshan avatar Jan 16 '24 07:01 zwshan

No, they are not the full set of parameters the API supports. Generally speaking, the set of all valid template parameters supported by any kernel is so huge, due to combinatorial explosion, that no amount of testing can cover it, and the CUTLASS library cannot generate all valid combinations.

thakkarV avatar Jan 16 '24 16:01 thakkarV
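
To make the combinatorial explosion concrete, here is an illustrative (not exhaustive) sketch of the independent template axes of a CUTLASS 2.x device-level GEMM. The specific types and shapes chosen below are assumptions for illustration; each commented axis can vary independently, so the number of valid instantiations grows multiplicatively:

  #include "cutlass/gemm/device/gemm.h"
  #include "cutlass/epilogue/thread/linear_combination.h"

  using IllustrativeGemm = cutlass::gemm::device::Gemm<
      cutlass::half_t, cutlass::layout::RowMajor,     // element and layout of A (several choices each)
      cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of B
      cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of C
      float,                                          // accumulator type
      cutlass::arch::OpClassTensorOp,                 // op class: SIMT or Tensor Cores
      cutlass::arch::Sm80,                            // target architecture
      cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile (many legal shapes)
      cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile (many legal shapes)
      cutlass::gemm::GemmShape<16, 8, 16>,            // instruction shape (fixed per arch/type)
      cutlass::epilogue::thread::LinearCombination<
          cutlass::half_t, 8, float, float>,          // epilogue functor (one of several)
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // swizzle policy
      3>;                                             // pipeline stages (2, 3, 4, ...)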

No, they are not the full set of parameters the API supports. Generally speaking, the set of all valid template parameters supported by any kernel is so huge, due to combinatorial explosion, that no amount of testing can cover it, and the CUTLASS library cannot generate all valid combinations.

I have noticed that when performing matrix multiplication on an A100 machine, the computation for dimensions MNK = (1024, 150, 256) and MNK = (1024, 1, 256) is significantly slower than cuBLAS. I have tried all the parameters listed on the website above, but I still can't match or exceed the performance of cuBLAS. What should I do now?

zwshan avatar Jan 17 '24 06:01 zwshan

What needs to be added is that MNK here means a multiplication of an (M, K) matrix by an (N, K) matrix, i.e., the B operand is transposed.

zwshan avatar Jan 17 '24 13:01 zwshan

Can you help me, please? @hwu36

zwshan avatar Jan 17 '24 13:01 zwshan

@zwshan There is no expectation that CUTLASS should match or exceed cuBLAS performance. The intent of CUTLASS is to provide developers with an additional tool alongside cuBLAS, for exploring functionality and requirements not currently supported by our libraries.

mnicely avatar Jan 17 '24 16:01 mnicely

Also, you should not be using a GEMM kernel for a GEMV problem. We have GEMV and batched GEMV implementations that are better suited to your problem shapes.

thakkarV avatar Jan 17 '24 18:01 thakkarV

You could use Nsight or nvprof to get the name of the kernel used by cuBLAS (for example, running "nsys profile --stats=true ./your_app" prints a per-kernel summary). The kernel name encodes the tile sizes used; a name like "ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_64x3_tn", for instance, indicates a 128x128 threadblock tile with 3 stages. Then we can fine-tune CUTLASS starting from the same tile sizes cuBLAS uses.

hwu36 avatar Jan 17 '24 19:01 hwu36

Thank you all! I will try it now!

zwshan avatar Jan 18 '24 05:01 zwshan

Could you please tell me how to use the GEMV kernel on an SM80 (A100) device?

zwshan avatar Jan 25 '24 07:01 zwshan

@thakkarV @hwu36

You could use Nsight or nvprof to get the name of the kernel used by cuBLAS. The kernel name encodes the tile sizes used. Then we can fine-tune CUTLASS starting from the same tile sizes cuBLAS uses.

I profiled it and found that cuBLAS uses a GEMV kernel.

zwshan avatar Jan 25 '24 07:01 zwshan

I want to use the GEMV kernel the same way I currently use the GEMM kernel, like this:

  using ElementOutput = float;
  using ElementAccumulator = float;
  using ElementComputeEpilogue = ElementAccumulator;
  using RowMajor = cutlass::layout::RowMajor;
  using ColumnMajor = cutlass::layout::ColumnMajor;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,                                    // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value, // <- elements per vectorized memory access; for float this is 4. This also becomes the vector width of math instructions in the epilogue
      ElementAccumulator,                               // <- data type of accumulator
      ElementComputeEpilogue>;                          // <- data type for alpha/beta in the linear combination function

  using CutlassGemm1 = cutlass::gemm::device::Gemm<
      cutlass::tfloat32_t,            // Data type of A matrix
      RowMajor,                       // Layout of A matrix
      cutlass::tfloat32_t,            // Data type of B matrix
      ColumnMajor,                    // Layout of B matrix
      ElementOutput,                  // Data type of C matrix
      ColumnMajor,                    // Layout of C matrix
      ElementAccumulator,             // Accumulator type
      cutlass::arch::OpClassTensorOp, // Tag indicating Tensor Cores
      cutlass::arch::Sm80,            // Tag indicating target GPU architecture
      cutlass::gemm::GemmShape<64, 64, 32>, // Threadblock tile shape
      cutlass::gemm::GemmShape<32, 32, 32>, // Warp tile shape
      cutlass::gemm::GemmShape<16, 8, 8>,   // Instruction shape (tf32 Tensor Core MMA on SM80)
      EpilogueOp,                     // Epilogue (reuses the functor defined above)
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      6>;                             // Number of pipeline stages

  CutlassGemm1 gemm_operator;
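
For completeness, here is a minimal sketch of how this operator would typically be launched with the classic device-level API; the problem size and pointers below are hypothetical placeholders:

  // Hypothetical problem size; alpha/beta scale the epilogue (D = alpha * A*B + beta * C).
  int M = 1024, N = 150, K = 256;
  float alpha = 1.0f, beta = 0.0f;

  // Stand-ins for real device allocations (e.g., from cudaMalloc).
  cutlass::tfloat32_t *ptr_A = nullptr, *ptr_B = nullptr;
  float *ptr_C = nullptr, *ptr_D = nullptr;

  CutlassGemm1::Arguments args(
      {M, N, K},      // GEMM problem size
      {ptr_A, K},     // TensorRef to A (row-major, leading dimension K)
      {ptr_B, K},     // TensorRef to B (column-major, leading dimension K)
      {ptr_C, M},     // TensorRef to C (column-major, leading dimension M)
      {ptr_D, M},     // TensorRef to D, the output (column-major, leading dimension M)
      {alpha, beta}); // Epilogue scalars

  cutlass::Status status = gemm_operator(args);  // launches the kernel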

zwshan avatar Jan 25 '24 07:01 zwshan

Here is the CUTLASS GEMV example: https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/gemv.cu

The code entry point is https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/device/gemv.h
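
Based on those two files, a minimal sketch of a device-side GEMV instantiation might look like the following. The exact template arguments here are an assumption about the cutlass::gemm::kernel::Gemv interface, so treat the linked unit test as the authoritative reference:

  #include "cutlass/gemm/device/gemv.h"
  #include "cutlass/gemm/kernel/gemv.h"
  #include "cutlass/epilogue/thread/linear_combination.h"

  // Assumed element types for an SM80 GEMV (y = alpha * A * x + beta * y);
  // the linked unit test covers more type/layout variants.
  using GemvKernel = cutlass::gemm::kernel::Gemv<
      cutlass::half_t,            // Element type of the matrix A
      cutlass::layout::RowMajor,  // Layout of A
      cutlass::half_t,            // Element type of the input vector x
      cutlass::half_t,            // Element type of the output vector y
      float,                      // Accumulator type
      cutlass::epilogue::thread::LinearCombination<
          cutlass::half_t, 1, float, float>>;

  using DeviceGemv = cutlass::gemm::device::Gemv<GemvKernel>;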

hwu36 avatar Jan 25 '24 15:01 hwu36

@zwshan has your issue been resolved?

mnicely avatar Feb 22 '24 15:02 mnicely

Solved, thank you.

zwshan avatar Feb 22 '24 17:02 zwshan