[QST] GemmUniversal is slower than GemmSplitKParallel when M and N are small and K is large
Hello,
I read this issue:
kernel::GemmUniversal with mode GemmUniversalMode::kGemmSplitKParallel will be equivalent to kernel::GemmSplitKParallel. The difference comes to the fore for the device::-scoped kernels, wherein device::GemmSplitKParallel calls a reduction kernel and device::GemmUniversal does not. However, it is recommended that you use device::GemmUniversal rather than device::GemmSplitKParallel, as the former is more frequently tested.
Originally posted by @jackkosaian in https://github.com/NVIDIA/cutlass/issues/702#issuecomment-1331414081
However, I tested the two implementations and found that GemmUniversal is much slower than GemmSplitKParallel when M and N are small and K is large, for example, M=64, N=64, K=4096:
GemmSplitKParallel: 0.011651 ms
UniversalGemmStreamK: 0.083712 ms
How could I configure the GemmUniversal to reproduce the speed of GemmSplitKParallel for M=64, N=64, K=4096? Thanks!
I profiled these GEMMs with CUTLASS v3.4.1 on an A5000 GPU.
Here is my testing code.
#include <iostream>
#include "cuda_runtime.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
// copy from https://github.com/NVIDIA/cutlass/blob/main/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu
void run_UniversalGemmStreamK(int m, int n, int k, int n_iter) {
  using ElementAccumulator = float;                  // <- data type of accumulator
  using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
  using ElementInputA = cutlass::half_t;             // <- data type of elements in input matrix A
  using ElementInputB = cutlass::half_t;             // <- data type of elements in input matrix B
  using ElementOutput = float;                       // <- data type of elements in output matrix D
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size(m, n, k);

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4),
      0); // <- Fill matrix A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4),
      0); // <- Fill matrix B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(), 1, ElementOutput(4), ElementOutput(-4),
      0); // <- Fill matrix C on host with uniform-distribution random data
  cutlass::reference::host::TensorFill(
      tensor_d.host_view()); // <- fill matrix D on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();

  using MMAOp = cutlass::arch::OpClassTensorOp;
  using SmArch = cutlass::arch::Sm80;
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementComputeEpilogue>;

  constexpr int NumStages = 4;

  // StreamK device GEMM implementation type
  using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal<
      ElementInputA, LayoutInputA,
      ElementInputB, LayoutInputB,
      ElementOutput, LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference
      NumStages,
      128 / cutlass::sizeof_bits<ElementInputA>::value,
      128 / cutlass::sizeof_bits<ElementInputB>::value>;

  // Initialize alpha and beta for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  typename DeviceGemmStreamK::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // kGemmSplitKParallel mode
      problem_size,                // problem_size
      16,                          // batch count / split-K slices
      {alpha, beta},               // epilogue parameters
      tensor_a.device_data(),      // ptr_A
      tensor_b.device_data(),      // ptr_B
      tensor_c.device_data(),      // ptr_C
      tensor_d.device_data(),      // ptr_D
      problem_size.mk().product(), // batch_stride_A
      problem_size.nk().product(), // batch_stride_B
      problem_size.mn().product(), // batch_stride_C
      problem_size.mn().product(), // batch_stride_D
      tensor_a.layout().stride(0), // stride_a
      tensor_b.layout().stride(0), // stride_b
      tensor_c.layout().stride(0), // stride_c
      tensor_d.layout().stride(0), // stride_d
      0};

  DeviceGemmStreamK gemm_op;
  size_t workspace_size = DeviceGemmStreamK::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // warmup
  cutlass::Status status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // time
  cudaEventRecord(start);
  for (int i = 0; i < n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  // print
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  printf("UniversalGemmStreamK: %f ms\n", milliseconds / n_iter);
}
// copy from https://github.com/NVIDIA/cutlass/blob/main/examples/06_splitK_gemm/splitk_gemm.cu
void run_GemmSplitKParallel(int m, int n, int k, int n_iter) {
  using ElementAccumulator = float;                  // <- data type of accumulator
  using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
  using ElementInputA = cutlass::half_t;             // <- data type of elements in input matrix A
  using ElementInputB = cutlass::half_t;             // <- data type of elements in input matrix B
  using ElementOutput = float;                       // <- data type of elements in output matrix D
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size(m, n, k);

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

  // Fill input and output matrices on host using CUTLASS helper functions
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4),
      0); // <- Fill matrix A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4),
      0); // <- Fill matrix B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(), 1, ElementOutput(4), ElementOutput(-4),
      0); // <- Fill matrix C on host with uniform-distribution random data
  cutlass::reference::host::TensorFill(
      tensor_d.host_view()); // <- fill matrix D on host with zeros

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();

  using MMAOp = cutlass::arch::OpClassTensorOp;
  using SmArch = cutlass::arch::Sm80;
  using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
  using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
  using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;

  // This code section describes the epilogue: D = alpha * accumulator + beta * C
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
      ElementOutput,
      128 / cutlass::sizeof_bits<ElementOutput>::value,
      ElementAccumulator,
      ElementComputeEpilogue>;

  using Gemm = cutlass::gemm::device::GemmSplitKParallel<
      ElementInputA, LayoutInputA,
      ElementInputB, LayoutInputB,
      ElementOutput, LayoutOutput,
      ElementAccumulator,
      MMAOp,
      SmArch,
      ShapeMMAThreadBlock,
      ShapeMMAWarp,
      ShapeMMAOp,
      EpilogueOp>;

  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
  ElementComputeEpilogue beta = ElementComputeEpilogue(0);

  // split K dimension into 16 partitions
  int split_k_slices = 16;

  typename Gemm::Arguments arguments{
      problem_size,          // <- problem size of matrix multiplication
      tensor_a.device_ref(), // <- reference to matrix A on device
      tensor_b.device_ref(), // <- reference to matrix B on device
      tensor_c.device_ref(), // <- reference to matrix C on device
      tensor_d.device_ref(), // <- reference to matrix D on device
      {alpha, beta},         // <- tuple of alpha and beta
      split_k_slices};       // <- k-dimension split factor

  Gemm gemm_op;
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  // warmup
  cutlass::Status status = gemm_op(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // time
  cudaEventRecord(start);
  for (int i = 0; i < n_iter; i++) {
    status = gemm_op(arguments, workspace.get());
    CUTLASS_CHECK(status);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  // print
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  printf("GemmSplitKParallel: %f ms\n", milliseconds / n_iter);
}
int main() {
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: "
              << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  // Define problem size
  const int length_m = 64;
  const int length_n = 64;
  const int length_k = 4096;
  const int n_iter = 100;

  run_GemmSplitKParallel(length_m, length_n, length_k, n_iter);
  run_UniversalGemmStreamK(length_m, length_n, length_k, n_iter);
  return 0;
}
Hello, have there been any updates on this issue? We very much appreciate your help 🙏 @thakkarV
The one blessed way of comparing the performance of CUTLASS kernels is via the profiler. Have you tried running your split-K kernel and comparing its performance against what the CUTLASS profiler reports?
How could I configure the GemmUniversal to reproduce the speed of GemmSplitKParallel for M=64, N=64, K=4096? Thanks!
I do not know in what ways the two implementations differ, but it could be a ton of things: the splitting factor, separate versus fused serial reductions versus fused parallel reductions, the threadblock rasterization policy, etc.
Have you made sure the various tile sizes and other template configurations are the same between the universal and split-K-specific configs?
@jackkosaian
In addition to what @thakkarV suggests, I see that the GemmUniversal kernel you're using uses stream-K (cutlass::gemm::threadblock::ThreadblockSwizzleStreamK), but you're passing in the mode as cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel.
This is not expected to work.
To use cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, one should use a non-stream-K swizzling functor (e.g., cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>)
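For example, reusing the type aliases from the testing code above, the only change would be the swizzle template parameter (a sketch, not a tuned configuration):

// Same configuration as DeviceGemmStreamK above, but with the identity
// swizzle so that GemmUniversalMode::kGemmSplitKParallel takes effect.
using DeviceGemmSplitKParallel = cutlass::gemm::device::GemmUniversal<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ShapeMMAThreadBlock,
    ShapeMMAWarp,
    ShapeMMAOp,
    EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // <-- changed
    NumStages,
    128 / cutlass::sizeof_bits<ElementInputA>::value,
    128 / cutlass::sizeof_bits<ElementInputB>::value>;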
@jackkosaian Thanks! cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> works! It boosts the performance of UniversalGemmStreamK by around 10x (0.083712 ms -> 0.007076 ms). Here are the results:
GemmSplitKParallel: 0.010189 ms
UniversalGemmStreamK: 0.007076 ms
@thakkarV Is there a document that I could follow to integrate the CUTLASS profiler in my testing code? Thanks!
https://github.com/NVIDIA/cutlass/blob/main/media/docs/profiler.md
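For instance, to profile GEMMs at this problem size with split-K, an invocation along these lines should work (illustrative; check the linked doc for the exact flags available in your build):

cutlass_profiler --operation=Gemm --m=64 --n=64 --k=4096 --split_k_slices=16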
Also, your version that uses GemmUniversal will need to launch a second reduction kernel after calling the GEMM in order to reduce the partial outputs. You can see an example of doing this here.
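For context: in kGemmSplitKParallel mode the GEMM writes split_k_slices partial accumulator tiles into the workspace, and those partials must then be summed into D. Below is a minimal hand-rolled CUDA sketch of that reduction, illustrative only; the linked example uses CUTLASS's own reduction kernel, which also applies the alpha/beta epilogue.

// Illustrative only: sums `slices` partial m-by-n float accumulator tiles,
// laid out contiguously in `workspace`, into the final output `d`.
__global__ void reduce_split_k_partials(
    float const* workspace, // slices * mn partial accumulators
    float* d,               // mn outputs
    int mn,                 // m * n
    int slices) {           // split-K factor, e.g. 16
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < mn) {
    float sum = 0.f;
    for (int s = 0; s < slices; ++s) {
      sum += workspace[s * mn + idx];
    }
    d[idx] = sum;
  }
}
// Example launch: reduce_split_k_partials<<<(mn + 255) / 256, 256>>>(ws, d, m * n, 16);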
Thanks for pointing that out! I will try it!