
[QST] GemmUniversal is slower than GemmSplitKParallel when M and N are small and K is large

Open · hychiang-git opened this issue

Hello,

I read this issue:

  • kernel::GemmUniversal with mode GemmUniversalMode::kGemmSplitKParallel will be equivalent to kernel::GemmSplitKParallel. The difference comes to fore for the device::-scoped kernels, wherein device::GemmSplitKParallel calls a reduction kernel and device::GemmUniversal does not. However, it is recommended that you use device::GemmUniversal rather than device::GemmSplitKParallel, as the former is more-frequently tested.

Originally posted by @jackkosaian in https://github.com/NVIDIA/cutlass/issues/702#issuecomment-1331414081

However, I tested both implementations and found that GemmUniversal is much slower (about 7x) than GemmSplitKParallel when M and N are small and K is large, for example, M=64, N=64, K=4096.

GemmSplitKParallel: 0.011651 ms
UniversalGemmStreamK: 0.083712 ms
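For scale, a small sketch that converts the two per-iteration times into effective throughput (the helper names are hypothetical). A GEMM performs 2*M*N*K floating-point operations, so this problem is only about 33.5 MFLOP. At the reported times that works out to roughly 2.88 TFLOP/s for GemmSplitKParallel and 0.40 TFLOP/s for the stream-K configuration. Both are likely far below what the GPU's tensor cores can sustain. That suggests that at this shape the result depends mostly on how the K dimension is spread across threadblocks, not on raw math throughput.

```cpp
#include <cassert>

// Floating-point operations in an M x N x K GEMM: one multiply and one add
// per (m, n, k) triple.
constexpr double gemm_flops(long long m, long long n, long long k) {
  return 2.0 * static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k);
}

// Effective throughput in TFLOP/s, given a measured time per GEMM in milliseconds.
double effective_tflops(long long m, long long n, long long k, double ms_per_gemm) {
  assert(ms_per_gemm > 0.0);
  return gemm_flops(m, n, k) / (ms_per_gemm * 1e-3) / 1e12;
}
```

For example, effective_tflops(64, 64, 4096, 0.011651) is about 2.88 and effective_tflops(64, 64, 4096, 0.083712) is about 0.40.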

How could I configure the GemmUniversal to reproduce the speed of GemmSplitKParallel for M=64, N=64, K=4096? Thanks!

I profiled the GEMMs with CUTLASS v3.4.1 on an A5000 GPU.

Here is my testing code.

#include <iostream>

#include "cuda_runtime.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"

// Adapted from https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu
void run_UniversalGemmStreamK(int m, int n, int k, int n_iter) {

    using ElementAccumulator = float;                   // <- data type of accumulator
    using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
    using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
    using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
    using ElementOutput = float;                        // <- data type of elements in output matrix D

    using LayoutInputA = cutlass::layout::RowMajor;
    using LayoutInputB = cutlass::layout::ColumnMajor;
    using LayoutOutput = cutlass::layout::RowMajor;

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    // Initialize tensors using CUTLASS helper functions
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0);  // <- Fill matrix A on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);  // <- Fill matrix B on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_c.host_view(),
        1,
        ElementOutput(4),
        ElementOutput(-4),
        0);  // <- Fill matrix C on host with uniform-distribution random data
    cutlass::reference::host::TensorFill(
        tensor_d.host_view());  // <- fill matrix D on host with zeros

    // Copy data from host to GPU
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();

    using MMAOp = cutlass::arch::OpClassTensorOp;
    using SmArch = cutlass::arch::Sm80;
    using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
    using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
    using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator,
        ElementComputeEpilogue>;

  constexpr int NumStages = 4;  // number of mainloop pipeline stages
  // StreamK device GEMM implementation type
  using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
      ElementInputA, LayoutInputA,
      ElementInputB, LayoutInputB,
      ElementOutput, LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- stream-K swizzle (the only difference from example 47's basic GEMM)
      NumStages,
      128 / cutlass::sizeof_bits<ElementInputA>::value,
      128 / cutlass::sizeof_bits<ElementInputB>::value>;

  // Initialize alpha and beta for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  typename DeviceGemmStreamK::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // kGemmSplitKParallel mode
    problem_size,                     // problem_size
    16,                               // batch count / splitk slices
    {alpha, beta},                     // epilogue parameters
    tensor_a.device_data(),                   // ptr_A
    tensor_b.device_data(),                   // ptr_B
    tensor_c.device_data(),                   // ptr_C
    tensor_d.device_data(),                   // ptr_D
    problem_size.mk().product(),      // batch_stride_A
    problem_size.nk().product(),      // batch_stride_B
    problem_size.mn().product(),      // batch_stride_C
    problem_size.mn().product(),      // batch_stride_D
    tensor_a.layout().stride(0),              // stride_a
    tensor_b.layout().stride(0),              // stride_b
    tensor_c.layout().stride(0),              // stride_c
    tensor_d.layout().stride(0),              // stride_d
    0};                               // avail_sms

  DeviceGemmStreamK gemm_op;
  size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // warmup
  cutlass::Status status;
  status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // time
  cudaEventRecord(start);
  for (int i=0; i<n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  // print
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  printf("UniversalGemmStreamK: %f ms\n", milliseconds/n_iter);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
}

// Adapted from https://github.com/NVIDIA/cutlass/blob/main/examples/06_splitK_gemm/splitk_gemm.cu
void run_GemmSplitKParallel(int m, int n, int k, int n_iter) {

    using ElementAccumulator = float;                   // <- data type of accumulator
    using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
    using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
    using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
    using ElementOutput = float;                        // <- data type of elements in output matrix D

    using LayoutInputA = cutlass::layout::RowMajor;
    using LayoutInputB = cutlass::layout::ColumnMajor;
    using LayoutOutput = cutlass::layout::RowMajor;

    // Create a tuple of problem size for matrix multiplication
    cutlass::gemm::GemmCoord problem_size(m, n, k);

    // Initialize tensors using CUTLASS helper functions
    cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
    cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
    cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

    // Fill input and output matrices on host using CUTLASS helper functions
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_a.host_view(),
        1,
        ElementInputA(4),
        ElementInputA(-4),
        0);  // <- Fill matrix A on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_b.host_view(),
        1,
        ElementInputB(4),
        ElementInputB(-4),
        0);  // <- Fill matrix B on host with uniform-distribution random data
    cutlass::reference::host::TensorFillRandomUniform(
        tensor_c.host_view(),
        1,
        ElementOutput(4),
        ElementOutput(-4),
        0);  // <- Fill matrix C on host with uniform-distribution random data
    cutlass::reference::host::TensorFill(
        tensor_d.host_view());  // <- fill matrix D on host with zeros

    // Copy data from host to GPU
    tensor_a.sync_device();
    tensor_b.sync_device();
    tensor_c.sync_device();
    tensor_d.sync_device();

    using MMAOp = cutlass::arch::OpClassTensorOp;
    using SmArch = cutlass::arch::Sm80;
    using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
    using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
    using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

    // Epilogue: D = alpha * accumulator + beta * C, with 128-bit vectorized accesses to D
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
        ElementOutput,
        128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator,
        ElementComputeEpilogue>;

    using Gemm = cutlass::gemm::device::GemmSplitKParallel<ElementInputA,
                                                        LayoutInputA,
                                                        ElementInputB,
                                                        LayoutInputB,
                                                        ElementOutput,
                                                        LayoutOutput,
                                                        ElementAccumulator,
                                                        MMAOp,
                                                        SmArch,
                                                        ShapeMMAThreadBlock,
                                                        ShapeMMAWarp,
                                                        ShapeMMAOp,
                                                        EpilogueOp>;

  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  // split K dimension into 16 partitions
  int split_k_slices = 16;
  typename Gemm::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                                     tensor_a.device_ref(),  // <- reference to matrix A on device
                                     tensor_b.device_ref(),  // <- reference to matrix B on device
                                     tensor_c.device_ref(),  // <- reference to matrix C on device
                                     tensor_d.device_ref(),  // <- reference to matrix D on device
                                     {alpha, beta},          // <- tuple of alpha and beta
                                     split_k_slices};        // <- k-dimension split factor
  Gemm gemm_op;
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // warmup
  cutlass::Status status;
  status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);
  // time
  cudaEventRecord(start);
  for (int i=0; i<n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  // print
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  printf("GemmSplitKParallel: %f ms\n", milliseconds/n_iter);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
}


int main() {

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  // Define problem size
  const int length_m = 64;
  const int length_n = 64;
  const int length_k = 4096;
  const int n_iter = 100;
  run_GemmSplitKParallel(length_m, length_n, length_k, n_iter);
  run_UniversalGemmStreamK(length_m, length_n, length_k, n_iter);
  return 0;
}


hychiang-git, Jun 12 '24 23:06

Hello, have there been any updates on this issue? We would very much appreciate your help 🙏 @thakkarV

nfrumkin, Jun 27 '24 00:06

The one blessed way of comparing the performance of CUTLASS kernels is via the profiler. Have you tried running your split-K kernel and comparing its performance against what the CUTLASS profiler reports?

  • How could I configure the GemmUniversal to reproduce the speed of GemmSplitKParallel for M=64, N=64, K=4096? Thanks!

I do not know in what ways the two implementations differ, but it could be a lot of things: the splitting factor, separate versus fused serial versus fused parallel reductions, the threadblock rasterization policy, etc.

Have you made sure the various tile sizes and other template configurations are the same between the universal and split-K-specific configs?

thakkarV, Jul 02 '24 11:07
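This back-of-the-envelope sketch makes the "splitting factor" point concrete. It shows how the 64x64x4096 problem maps onto the 128x128x32 threadblock tile used in both test functions. It assumes an even split of K. split_k_launch and its fields are hypothetical names, not CUTLASS APIs; CUTLASS's own threadblock swizzles compute the actual launch grid.

```cpp
#include <cassert>

// Hypothetical helper (not a CUTLASS API): how an M x N x K problem maps onto
// threadblock tiles when K is split evenly into `slices` pieces.
struct SplitKLaunch {
  int tiles_m;                 // output tiles along M
  int tiles_n;                 // output tiles along N
  int k_iterations_total;      // main-loop iterations over K for one output tile
  int k_iterations_per_slice;  // iterations each threadblock runs after splitting K
  int threadblocks;            // tiles_m * tiles_n * slices
};

inline int ceil_div(int a, int b) { return (a + b - 1) / b; }

SplitKLaunch split_k_launch(int m, int n, int k,
                            int tile_m, int tile_n, int tile_k, int slices) {
  assert(m > 0 && n > 0 && k > 0 && tile_m > 0 && tile_n > 0 && tile_k > 0 && slices > 0);
  SplitKLaunch s{};
  s.tiles_m = ceil_div(m, tile_m);
  s.tiles_n = ceil_div(n, tile_n);
  s.k_iterations_total = ceil_div(k, tile_k);
  s.k_iterations_per_slice = ceil_div(s.k_iterations_total, slices);
  s.threadblocks = s.tiles_m * s.tiles_n * slices;
  return s;
}
```

Take M=64, N=64, K=4096 with 16 slices. That gives a single 128x128 output tile, of which only the 64x64 corner (25%) is useful. There are 128 K-iterations in total, 8 per slice, spread over 16 threadblocks. With a single slice, one threadblock would run all 128 iterations while the rest of the GPU sits idle. That is why the splitting factor can dominate at this shape.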

@jackkosaian

thakkarV, Jul 02 '24 11:07

In addition to what @thakkarV suggests, I see that the GemmUniversal kernel you're using uses stream-K (cutlass::gemm::threadblock::ThreadblockSwizzleStreamK), but you're passing in the mode as cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel.

This is not expected to work.

To use cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, one should use a non-stream-K swizzling functor (e.g., cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>).

jackkosaian, Jul 02 '24 12:07
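Here is a minimal, library-free sketch that turns that rule into a compile-time guard. UniversalMode, IdentitySwizzle, StreamKSwizzle and mode_is_supported are hypothetical stand-ins, defined only so the example compiles on its own. In a real CUTLASS build, the same std::is_same test could be written against cutlass::gemm::threadblock::ThreadblockSwizzleStreamK and cutlass::gemm::GemmUniversalMode.

```cpp
#include <type_traits>

// Hypothetical stand-ins so this sketch compiles without CUTLASS. They are NOT
// cutlass::gemm::GemmUniversalMode or CUTLASS's real swizzle functors.
enum class UniversalMode { kGemm, kGemmSplitKParallel };
struct IdentitySwizzle {};  // plays the role of GemmIdentityThreadblockSwizzle<>
struct StreamKSwizzle {};   // plays the role of ThreadblockSwizzleStreamK

// Encodes the single rule from this thread: split-K-parallel mode needs a
// non-stream-K swizzle. It is not a complete list of valid combinations.
template <typename Swizzle>
constexpr bool mode_is_supported(UniversalMode mode) {
  return !(std::is_same<Swizzle, StreamKSwizzle>::value &&
           mode == UniversalMode::kGemmSplitKParallel);
}

// Usage: guard a configuration at compile time. The fixed configuration passes;
// swapping IdentitySwizzle for StreamKSwizzle (the original setup) fails to compile.
constexpr UniversalMode kMode = UniversalMode::kGemmSplitKParallel;
static_assert(mode_is_supported<IdentitySwizzle>(kMode),
              "kGemmSplitKParallel requires a non-stream-K threadblock swizzle");
```

Stream-K with plain kGemm mode (as in example 47) is still accepted by this check.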

@jackkosaian Thanks! cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> works! It boosts the performance of UniversalGemmStreamK by around 12x (0.083712 ms -> 0.007076 ms). Here is the result:

GemmSplitKParallel: 0.010189 ms
UniversalGemmStreamK: 0.007076 ms

@thakkarV Is there a document that I could follow to integrate the CUTLASS profiler in my testing code? Thanks!

hychiang-git, Jul 02 '24 17:07

https://github.com/NVIDIA/cutlass/blob/main/media/docs/profiler.md

thakkarV, Jul 02 '24 17:07

Also, your version that uses GemmUniversal will need to launch a second reduction kernel after the GEMM in order to reduce the partial outputs. You can see an example of doing this here.

jackkosaian, Jul 02 '24 17:07
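The following CPU-only sketch shows what that second step does. It assumes an even K split, uses float everywhere for simplicity, and follows the layouts in the test code (row-major A, column-major B, row-major C/D). The GEMM step leaves one partial M x N result per slice in a workspace, and only a separate reduction produces D. The function names are hypothetical, and in this sketch alpha and beta are applied during the reduction. Per the quote at the top of the thread, device::GemmSplitKParallel already runs a reduction kernel and device::GemmUniversal does not. So the 0.007076 ms GemmUniversal timing above does not yet include this step.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Step 1: each of `slices` K-ranges writes its own M x N partial product into a
// workspace of slices * M * N floats (partial s starts at offset s * M * N).
std::vector<float> splitk_partials(const std::vector<float>& A,  // M x K, row-major
                                   const std::vector<float>& B,  // K x N, column-major
                                   int M, int N, int K, int slices) {
  assert(slices > 0 && K % slices == 0);
  const int k_per_slice = K / slices;
  std::vector<float> workspace(static_cast<std::size_t>(slices) * M * N, 0.0f);
  for (int s = 0; s < slices; ++s)  // on a GPU, slices run as separate threadblocks
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = s * k_per_slice; k < (s + 1) * k_per_slice; ++k)
          acc += A[m * K + k] * B[n * K + k];  // column-major B: (k, n) is at n * K + k
        workspace[(static_cast<std::size_t>(s) * M + m) * N + n] = acc;
      }
  return workspace;
}

// Step 2 (the separate reduction): D = alpha * sum over slices + beta * C.
std::vector<float> splitk_reduce(const std::vector<float>& workspace,
                                 const std::vector<float>& C,  // M x N, row-major
                                 int M, int N, int slices, float alpha, float beta) {
  std::vector<float> D(static_cast<std::size_t>(M) * N);
  for (int i = 0; i < M * N; ++i) {
    float sum = 0.0f;
    for (int s = 0; s < slices; ++s)
      sum += workspace[static_cast<std::size_t>(s) * M * N + i];
    D[i] = alpha * sum + beta * C[i];
  }
  return D;
}
```

With the shapes from this issue (64x64 output, 16 slices, float partials), the workspace in this sketch holds 16 * 64 * 64 floats, i.e. 256 KiB.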

Thanks for pointing that out! I will try that!

hychiang-git, Jul 02 '24 17:07