[QST/BUG] Why does my CuTe kernel transfer so much more data between L2 and gmem than the cuBLAS kernel?
What is your question?
I am learning to use CuTe to build an HGEMM kernel. Tested on an A10 GPU, the CuTe kernel performs well for smaller problem sizes such as m/n/k = 4096, but it is much slower than the cuBLAS kernel for m/n/k = 16384/16384/16384, as shown below.
Here is the ncu profile for m/n/k = 16384/16384/16384.
The biggest difference between my CuTe kernel and the cuBLAS kernel is in the memory chart:
[Screenshot: ncu memory chart, cuBLAS kernel]
[Screenshot: ncu memory chart, my CuTe kernel]
Why does my CuTe kernel move so much more data gmem->L2 and L2->shared than the cuBLAS kernel, and how should I modify it to improve performance for large problem sizes?
Here is my CuTe kernel:
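To put the memory chart in context, here is a rough host-side sketch (plain C++, hypothetical helper names) of the two extremes for this problem: the compulsory DRAM traffic if A and B are each read once and C is written once, and the A/B read traffic if no CTA ever hits in L2 and every CTA tile streams its full slices of A and B from gmem. Neither models the real cache; they are only meant as bounds.

```cpp
#include <cassert>
#include <cstdint>

using u64 = std::uint64_t;

// Lower bound: A (M x K) and B (K x N) are read from DRAM exactly once and
// C (M x N) is written exactly once. elem_bytes = 2 for fp16.
u64 compulsory_bytes(u64 M, u64 N, u64 K, u64 elem_bytes = 2) {
  return (M * K + K * N + M * N) * elem_bytes;
}

// A/B read traffic if nothing is ever reused through L2: each BM x BN output
// tile streams its own BM x K slice of A and K x BN slice of B from DRAM.
u64 no_reuse_read_bytes(u64 M, u64 N, u64 K, u64 BM, u64 BN, u64 elem_bytes = 2) {
  assert(BM > 0 && BN > 0);
  u64 tiles_m = (M + BM - 1) / BM;
  u64 tiles_n = (N + BN - 1) / BN;
  return tiles_m * tiles_n * (BM + BN) * K * elem_bytes;
}
```

For m/n/k = 16384 with 128x128 tiles these come to 1.5 GiB and 128 GiB respectively, so the measured gmem->L2 volume roughly shows where a kernel sits between perfect L2 reuse and none.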
namespace config {
using namespace cute;
template <int BM, int BN, int BK, int Stage, int SWIZZLE>
struct HGemmConfig {
using ADataType = half_t;
using BDataType = half_t;
using CDataType = half_t;
using TileShape = Shape<Int<BM>, Int<BN>, Int<BK>>;
using TiledMma = TiledMMA<MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>, Layout<Shape<_2, _2, _1>>, Tile<_32, _32, _16>>;
static constexpr int ThreadCount = size(TiledMma{});
// A: row-major (m, k)
using SmemLayoutAtomA = decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, Shape<Int<BM>, Int<BK>, Int<Stage>>{}));
using G2STiledCopyA = decltype(make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, ADataType>{}, Layout<Shape<_32, _4>, Stride<_4, _1>>{}, Layout<Shape<_1, _8>>{}));
using S2RCopyAtomA = Copy_Atom<SM75_U32x4_LDSM_N, ADataType>;
using S2RTiledCopyA = decltype(make_tiled_copy_A(S2RCopyAtomA{}, TiledMma{}));
// B: row-major (k, n)
using SmemLayoutAtomB = decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, Shape<Int<BN>, Int<BK>, Int<Stage>>{}));
using G2STiledCopyB = decltype(make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, BDataType>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}, Layout<Shape<_8, _1>>{}));
using S2RCopyAtomB = Copy_Atom<SM75_U16x8_LDSM_T, BDataType>;
using S2RTiledCopyB = decltype(make_tiled_copy_B(S2RCopyAtomB{}, TiledMma{}));
// C: row-major (m, n)
using SmemLayoutAtomC = decltype(composition(Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, Shape<Int<BM>, Int<BN>>{}));
using R2SCopyAtomC = Copy_Atom<UniversalCopy<int>, CDataType>;
using R2STiledCopyC = decltype(make_tiled_copy_C(R2SCopyAtomC{}, TiledMma{}));
using S2GTiledCopyC = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint128_t>, CDataType>{}, Layout<Shape<_32, _4>, Stride<_4, _1>>{}, Layout<Shape<_1, _8>>{}));
static constexpr int SmemSizeA = cosize_v<SmemLayoutA> * sizeof(ADataType);
static constexpr int SmemSizeB = cosize_v<SmemLayoutB> * sizeof(BDataType);
static constexpr int SmemSizeC = cosize_v<SmemLayoutC> * sizeof(CDataType);
static constexpr int SmemSize = cute::max(SmemSizeA + SmemSizeB, SmemSizeC);
static constexpr int kStage = Stage;
static constexpr int kSwizzle = SWIZZLE;
};
} // namespace config
template <typename Config>
__global__ void CUTEMultiStageKernel(const half* A, const half* B, half* C, int M, int N, int K) {
using namespace cute;
constexpr int SWIZZLE = Config::kSwizzle;
int tid = threadIdx.x;
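// Threadblock swizzle: SWIZZLE consecutive blockIdx.x values map to SWIZZLE consecutive M tiles of the same N tile
int bx = blockIdx.x / SWIZZLE; // N-tile index
int by = blockIdx.y * SWIZZLE + blockIdx.x % SWIZZLE; // M-tile index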
Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(K, Int<1>{})); // (M,K):(K,1)
Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(Int<1>{}, N)); // (N,K):(1,N)
Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(N, Int<1>{})); // (M,N):(N,1)
auto cta_tiler = typename Config::TileShape{}; // (BM, BN, BK)
auto cta_coord = make_coord(by, bx, _); // (m, n, k)
Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); // (BM, BK, k_tile)
Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step<X, _1, _1>{}); // (BN, BK, k_tile)
Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); // (BM, BN)
extern __shared__ uint8_t raw_smem[];
half* p_sA = reinterpret_cast<half*>(raw_smem);
half* p_sB = reinterpret_cast<half*>(raw_smem + Config::SmemSizeA);
Tensor sA = make_tensor(make_smem_ptr(p_sA), typename Config::SmemLayoutA{}); // (BM, BK, Stage)
Tensor sB = make_tensor(make_smem_ptr(p_sB), typename Config::SmemLayoutB{}); // (BN, BK, Stage)
typename Config::TiledMma tiled_mma{};
ThrMMA thr_mma = tiled_mma.get_slice(tid);
Tensor tCrA = thr_mma.partition_fragment_A(gA(_, _, 0));
Tensor tCrB = thr_mma.partition_fragment_B(gB(_, _, 0));
Tensor tCrC = thr_mma.partition_fragment_C(gC);
clear(tCrC);
typename Config::G2STiledCopyA tiled_g2s_A{};
ThrCopy thr_g2s_A = tiled_g2s_A.get_slice(tid);
Tensor tAgA = thr_g2s_A.partition_S(gA);
Tensor tAsA = thr_g2s_A.partition_D(sA);
typename Config::G2STiledCopyB tiled_g2s_B{};
ThrCopy thr_g2s_B = tiled_g2s_B.get_slice(tid);
Tensor tBgB = thr_g2s_B.partition_S(gB);
Tensor tBsB = thr_g2s_B.partition_D(sB);
typename Config::S2RTiledCopyA tiled_s2r_A{};
ThrCopy thr_s2r_A = tiled_s2r_A.get_slice(tid);
Tensor tCsA = thr_s2r_A.partition_S(sA);
Tensor tCrA_retile = thr_s2r_A.retile_D(tCrA);
typename Config::S2RTiledCopyB tiled_s2r_B{};
ThrCopy thr_s2r_B = tiled_s2r_B.get_slice(tid);
Tensor tCsB = thr_s2r_B.partition_S(sB);
Tensor tCrB_retile = thr_s2r_B.retile_D(tCrB);
constexpr int STAGE = Config::kStage;
// Prologue: issue cp.async loads for the first STAGE - 1 k-tiles, one commit group per k-tile
for (int i = 0; i < STAGE - 1; ++i) {
cute::copy(tiled_g2s_A, tAgA(_, _, _, i), tAsA(_, _, _, i));
cute::copy(tiled_g2s_B, tBgB(_, _, _, i), tBsB(_, _, _, i));
cp_async_fence();
}
int k_tile_count = size<3>(tAgA);
int k_tile_next = STAGE - 1;
int smem_pipe_write = STAGE - 1;
int smem_pipe_read = 0;
cp_async_wait<STAGE - 2>();
__syncthreads();
int ik = 0;
cute::copy(tiled_s2r_A, tCsA(_, _, ik, smem_pipe_read), tCrA_retile(_, _, ik));
cute::copy(tiled_s2r_B, tCsB(_, _, ik, smem_pipe_read), tCrB_retile(_, _, ik));
constexpr int CHUNK_K = size<2>(tCrA);
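// Main loop: at ik == 0, prefetch k-tile k_tile_next into stage smem_pipe_write; at ik == CHUNK_K - 1, wait for the next stage and advance smem_pipe_read.
// The S2R copy for ik_next is issued after the MMA on ik so the two can overlap.
for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {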
for (int ik = 0; ik < CHUNK_K; ++ik) {
if (ik == 0) {
if (k_tile_next < k_tile_count) {
cute::copy(tiled_g2s_A, tAgA(_, _, _, k_tile_next), tAsA(_, _, _, smem_pipe_write));
cute::copy(tiled_g2s_B, tBgB(_, _, _, k_tile_next), tBsB(_, _, _, smem_pipe_write));
++k_tile_next;
smem_pipe_write = (smem_pipe_write + 1) % STAGE;
}
cp_async_fence();
}
cute::gemm(tiled_mma, tCrA(_, _, ik), tCrB(_, _, ik), tCrC);
if (ik == CHUNK_K - 1) {
cp_async_wait<STAGE - 2>();
__syncthreads();
smem_pipe_read = (smem_pipe_read + 1) % STAGE;
}
int ik_next = (ik + 1) % CHUNK_K;
cute::copy(tiled_s2r_A, tCsA(_, _, ik_next, smem_pipe_read), tCrA_retile(_, _, ik_next));
cute::copy(tiled_s2r_B, tCsB(_, _, ik_next, smem_pipe_read), tCrB_retile(_, _, ik_next));
}
}
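// Epilogue: reuse the A/B shared-memory buffer to stage C, then write it to gmem with 128-bit stores
half* p_sC = reinterpret_cast<half*>(raw_smem);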
Tensor sC = make_tensor(make_smem_ptr(p_sC), typename Config::SmemLayoutC{});
typename Config::R2STiledCopyC tiled_r2s_C{};
ThrCopy thr_r2s_C = tiled_r2s_C.get_slice(tid);
Tensor tCrC_retile = thr_r2s_C.retile_S(tCrC);
Tensor tCsC = thr_r2s_C.partition_D(sC);
__syncthreads();
cute::copy(tiled_r2s_C, tCrC_retile, tCsC);
typename Config::S2GTiledCopyC tiled_s2g_C{};
ThrCopy thr_s2g_C = tiled_s2g_C.get_slice(tid);
Tensor tDsC = thr_s2g_C.partition_S(sC);
Tensor tDgC = thr_s2g_C.partition_D(gC);
__syncthreads();
cute::copy(tiled_s2g_C, tDsC, tDgC);
}
void CUTEMultiStage(half* A, half* B, half* C, int M, int N, int K) {
constexpr int BM = 128;
constexpr int BN = 128;
constexpr int BK = 32;
constexpr int STAGE = 3;
constexpr int SWIZZLE = 4;
using hgemm_config = config::HGemmConfig<BM, BN, BK, STAGE, SWIZZLE>;
constexpr int smem_max_size = hgemm_config::SmemSize;
static bool initialized = false;
if (!initialized) {
PD_CUDA_CHECK(cudaFuncSetAttribute(CUTEMultiStageKernel<hgemm_config>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_max_size));
initialized = true;
}
dim3 block(hgemm_config::ThreadCount);
dim3 grid(PD_ROUND_DIV(N, BN) * SWIZZLE, PD_ROUND_DIV(PD_ROUND_DIV(M, BM), SWIZZLE));
CUTEMultiStageKernel<hgemm_config><<<grid, block, smem_max_size>>>(A, B, C, M, N, K);
}
You are likely thrashing L2 locality by not doing any block ID remapping / swizzling.
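One way to see what the blockIdx remapping in the kernel (the bx/by lines) does to locality is to enumerate the CTAs that are resident at the same time. This is an illustration under two assumptions that do not come from this thread: CTAs are dispatched in linear order blockIdx.y * gridDim.x + blockIdx.x, and about 144 CTAs are resident at once (72 SMs on an A10 x 2 CTAs per SM, limited by shared memory). The hypothetical first_wave helper reports how many distinct M and N tiles that first wave touches and the unique A+B bytes it needs per BK step.

```cpp
#include <cassert>
#include <set>

struct WaveFootprint {
  int m_tiles;                // distinct M tiles (row panels of A) touched by the wave
  int n_tiles;                // distinct N tiles (column panels of B) touched by the wave
  long long bytes_per_ktile;  // unique A + B bytes the wave reads per BK step
};

// Replays the grid shape of CUTEMultiStage and the bx/by remapping of
// CUTEMultiStageKernel for the first `wave_ctas` CTAs, assuming (not guaranteed)
// that CTAs are dispatched in linear order blockIdx.y * gridDim.x + blockIdx.x.
WaveFootprint first_wave(int tiles_m, int tiles_n, int swizzle, int wave_ctas,
                         int BM, int BN, int BK, int elem_bytes = 2) {
  assert(swizzle > 0 && wave_ctas > 0);
  const int grid_x = tiles_n * swizzle;                  // PD_ROUND_DIV(N, BN) * SWIZZLE
  const int grid_y = (tiles_m + swizzle - 1) / swizzle;  // PD_ROUND_DIV(tiles_m, SWIZZLE)
  std::set<int> ms, ns;
  for (int linear = 0; linear < wave_ctas && linear < grid_x * grid_y; ++linear) {
    const int block_x = linear % grid_x;
    const int block_y = linear / grid_x;
    const int bx = block_x / swizzle;                      // N-tile index
    const int by = block_y * swizzle + block_x % swizzle;  // M-tile index
    if (by >= tiles_m) continue;  // out-of-range tile (the kernel itself has no such guard)
    ms.insert(by);
    ns.insert(bx);
  }
  const int m = static_cast<int>(ms.size());
  const int n = static_cast<int>(ns.size());
  return {m, n, static_cast<long long>(m * BM + n * BN) * BK * elem_bytes};
}
```

With SWIZZLE = 4 the first wave covers 4 M tiles x 36 N tiles (327,680 bytes per k-tile); with SWIZZLE = 1 it covers 2 x 128 (1,064,960 bytes). A squarer footprint needs fewer unique bytes per step, but the sketch does not model how far CTAs drift apart along K, which decides whether that data actually stays in L2.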
Hi @thakkarV, thanks for your reply. The performance is still bad after I removed the threadblock swizzle.
I suspect this abnormal data reading is related to the cp.async in the G2S (global-to-shared) copy, but I don't know the specific cause.
I also found that if I change
constexpr int BN = 128;
to
constexpr int BN = 256;
the abnormal data reading disappears. Here is the ncu result:
The gmem->L2 traffic is still higher than in the cuBLAS version, but I think it is acceptable here.
The throughput for various large input sizes also looks good.
However, I don't understand how the block tile shape affects L2 cache locality here. Is there something wrong in my code, or a bug somewhere? Can anyone help?
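A quick way to compare the two configurations is to redo the shared memory arithmetic from HGemmConfig. This sketch assumes the swizzled layouts are compact (cosize equals size), ignores register and thread limits, and uses assumed SM86 figures for the A10: 100 KB of shared memory per SM and 1 KB reserved per CTA. The helper names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

// Shared memory per CTA, mirroring HGemmConfig: all A and B stages are live during
// the main loop, and the C staging tile reuses the same buffer in the epilogue.
int smem_bytes(int BM, int BN, int BK, int stages, int elem_bytes = 2) {
  int a = BM * BK * stages * elem_bytes;
  int b = BN * BK * stages * elem_bytes;
  int c = BM * BN * elem_bytes;
  return std::max(a + b, c);
}

// CTAs per SM limited by shared memory only. Defaults are assumed SM86 values:
// 100 KB per SM and 1 KB reserved by the system for each resident CTA.
int ctas_per_sm_by_smem(int smem_per_cta, int smem_per_sm = 100 * 1024,
                        int reserved_per_cta = 1024) {
  assert(smem_per_cta > 0);
  return smem_per_sm / (smem_per_cta + reserved_per_cta);
}
```

Under those assumptions BN = 128 fits 2 CTAs per SM (about 144 on 72 SMs) and BN = 256 fits 1 (about 72), so both keep 2,359,296 output elements in flight per wave. What does change is per-tile reuse: each BK step loads (BM + BN) x BK elements for BM x BN outputs, i.e. 256/16384 for 128x128 versus 384/32768 for 128x256, about 25% less L2->shared traffic per output element for the larger tile.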
By "block ID remapping / swizzling", Vijay meant the Tile Schedulers that are part of CUTLASS and cuBLAS, not the swizzling of data or threadblocks. CUTLASS and cuBLAS have many Tile Schedulers, including Split-K, Stream-K, and other skews in how block IDs are assigned to work tiles. Many of these strategies aim to increase L2 locality. cuBLAS uses a heuristic to choose the best Tile Scheduler for your problem, and that choice certainly changes with problem size.
For more information, I recommend the (full version of our) Stream-K paper: https://arxiv.org/pdf/2301.03598
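The core idea of Stream-K, as described in the paper, is to split the total number of MAC-loop (BK) iterations across all output tiles evenly among a fixed number of persistent CTAs, so a CTA's share can begin or end in the middle of a tile. The following is a simplified sketch of that partitioning only (hypothetical names; no hybrid data-parallel phase and no fixup of partial tiles), not the CUTLASS or cuBLAS scheduler.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// One contiguous run of BK iterations that a worker performs on a single output tile.
struct Segment {
  int tile;     // linear output-tile index
  int k_begin;  // first BK iteration within the tile (inclusive)
  int k_end;    // last BK iteration within the tile (exclusive)
};

// Basic Stream-K split: tiles * iters_per_tile iterations are divided as evenly as
// possible among `workers` persistent CTAs. A worker's range may cross tile
// boundaries; partial tiles would need a cross-CTA reduction not modeled here.
std::vector<Segment> stream_k_segments(int worker, int workers, int tiles, int iters_per_tile) {
  assert(workers > 0 && 0 <= worker && worker < workers && iters_per_tile > 0);
  const long long total = 1LL * tiles * iters_per_tile;
  const long long begin = total * worker / workers;
  const long long end = total * (worker + 1) / workers;
  std::vector<Segment> segs;
  for (long long it = begin; it < end;) {
    const int tile = static_cast<int>(it / iters_per_tile);
    const long long tile_start = 1LL * tile * iters_per_tile;
    const long long stop = std::min(end, tile_start + iters_per_tile);
    segs.push_back({tile, static_cast<int>(it - tile_start), static_cast<int>(stop - tile_start)});
    it = stop;
  }
  return segs;
}
```

For example, with 5 tiles of 4 iterations and 3 workers, worker 1 gets iterations [6, 13): the second half of tile 1, all of tile 2, and the first iteration of tile 3. The CTAs sharing tiles 1 and 3 would then have to combine their partial accumulators.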
Hi @ccecka, can I ask a basic question? I skimmed the Stream-K paper, and my (rudimentary) understanding is that it beats cuBLAS by better minimizing the tail effect. Since this GEMM is large, I assume there are lots of threadblocks, so would we expect cuBLAS to perform about the same as Stream-K?
In other words, does Stream-K outperform cuBLAS for GEMMs with a large number of blocks?
Thanks!
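The tail effect in question can be estimated with simple wave-quantization arithmetic for a classic data-parallel tiling. This sketch (hypothetical helper name) assumes about 144 resident 128x128 CTAs on the A10 (72 SMs x 2 CTAs per SM) and ignores L2 effects entirely.

```cpp
#include <cassert>

// Fraction of CTA slots doing useful work when `tiles` equal-sized output tiles are
// scheduled in waves of `concurrent` resident CTAs; only the last wave can be partial.
double wave_efficiency(long long tiles, long long concurrent) {
  assert(tiles > 0 && concurrent > 0);
  const long long waves = (tiles + concurrent - 1) / concurrent;
  return static_cast<double>(tiles) / static_cast<double>(waves * concurrent);
}
```

Under that assumption, 16384^3 gives 16384 tiles in 114 waves with about 99.8% of CTA slots busy, while 4096^3 gives 1024 tiles in 8 waves at about 88.9%, so for the large problem the tail from wave quantization is small.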
Any update? Should I consider this the expected behavior for this CuTe kernel?
This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
@ssiu
In other words, does Stream-K outperform cuBLAS for GEMMs with a large number of blocks?
cuBLAS uses Stream-K.
This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.