cutlass icon indicating copy to clipboard operation
cutlass copied to clipboard

[QST] How to use slicedK in GEMM?

Open Arsmart123 opened this issue 3 years ago • 3 comments

Hi! I have written code for slicedK in GEMM, but it seems very slow. I tried to understand cutlass's slicedK but could not, so I am posting my code here and explaining my approach, hoping someone can give me some suggestions. Thank you!

So the basic idea comes from cutlass's post. My allocation is: 64 threads (2 warps) compute a 64 * 64 tile. Each thread reads 4 values from A and 4 values from B and calculates 64 results. We arrange 256 threads (8 warps) in total, so this is sliced-4: 4 groups of warps compute the same 64 * 64 location, but K is split, so the final results need to be reduced.

In shared memory, each thread stores 1/4 of its 64 results at a time (one 32 * 32 sub-tile per slice, taking up 16KB in total), and this is repeated 4 times. Each time, we add the values at corresponding locations across the four slices to get a 4KB result, which is written to global memory.

The nsight compute shows: image

Thank you!!!!!

(By the way, where is the detailed code for cutlass's slicedK? I am completely lost in it.)

#include<iostream>
using namespace std;
#include <cstdint>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <vector>

#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])


// Host-side reference check for C = A * B using a naive triple loop.
//
// A is read as a row-major m x k matrix (A[i * k + p]), B as a row-major
// k x n matrix (B[p * n + j]) and C as a row-major m x n matrix (C[i * n + j]).
//
// Returns true when every element of C is within a relative error of 1e-5 of
// the reference value. On the first mismatch the offending element is printed
// and false is returned.
bool check(const float *A,
	const float *B,
	const float *C,
	int m, int n, int k) {
	const float rel_tol = 1e-5f;

	for (int i = 0; i < m; ++i) {
		for (int j = 0; j < n; ++j) {
			float sum = 0.f;
			for (int p = 0; p < k; ++p) {
				// size_t indexing so large matrices cannot overflow int.
				sum += A[(size_t)i * k + p] * B[(size_t)p * n + j];
			}

			const float got = C[(size_t)i * n + j];
			const float err = std::fabs(sum - got);

			// Written as !(err <= bound) instead of (err / |sum| > tol):
			//  - a NaN in C (or a NaN error) now fails the check instead of
			//    slipping through a comparison that is always false for NaN;
			//  - no division, so a reference value of exactly 0 is handled
			//    without producing 0/0.
			if (!(err <= rel_tol * std::fabs(sum))) {
				printf("C[%d][%d] not match, %f vs %f\n", i, j, sum, got);
				return false;
			}
		}
	}

	return true;
}


// Converts a generic pointer into a 32-bit shared-state-space address.
//
// The PTX first maps the generic address to the shared window
// (cvta.to.shared.u64) and then narrows the 64-bit result to 32 bits
// (cvt.u32.u64). The returned value is what the ld.shared / st.shared
// wrappers in this file take as their `addr` operand.
//
// smem_ptr is expected to point into shared memory; nothing here checks it.
__device__ __forceinline__
uint32_t smem_u32addr(const void *smem_ptr) {
	uint32_t addr;
	asm("{.reg .u64 u64addr;\n"
		" cvta.to.shared.u64 u64addr, %1;\n"
		" cvt.u32.u64 %0, u64addr;}\n"
		: "=r"(addr)
		: "l"(smem_ptr)
	);

	return addr;
}

// Predicated 32-bit global load through the non-coherent (read-only) path.
//
//   reg   - destination; overwritten only when `guard` is true. When `guard`
//           is false the load is skipped and `reg` keeps the value it had on
//           entry (the operand is declared read-write for that reason, so the
//           caller should have initialised it).
//   ptr   - global address of the float to load.
//   guard - load is issued only when true.
//
// With nvcc 11.4 or newer and an SM 7.5+ target the load carries the
// L2::128B prefetch hint; otherwise a plain ld.global.nc is used.
__device__ __forceinline__
void ldg32_nc(float &reg, const void *ptr, bool guard) {
	asm volatile (
		"{.reg .pred p;\n"
		" setp.ne.b32 p, %2, 0;\n"
// Compare (major, minor) as a pair: testing MAJOR >= 11 && MINOR >= 4
// wrongly rejects 12.0 - 12.3 and any later x.0 - x.3 release.
#if (__CUDACC_VER_MAJOR__ > 11 || \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) && \
    __CUDA_ARCH__ >= 750
		" @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
#else
		" @p ld.global.nc.f32 %0, [%1];}\n"
#endif
		// "+f": the register is only conditionally written, so it must be
		// treated as an input too or its value is undefined when !guard.
		: "+f"(reg)
		: "l"(ptr), "r"((int)guard)
		);
}


// Unconditional 32-bit global load through the non-coherent (read-only) path.
//
//   reg - destination, always overwritten with the loaded value.
//   ptr - global address of the float to load; the caller is responsible for
//         making sure it is in bounds (there is no guard predicate here).
//
// With nvcc 11.4 or newer and an SM 7.5+ target the load carries the
// L2::128B prefetch hint; otherwise a plain ld.global.nc is used.
//
// The former "mov.b32 %0, 0" ahead of the load was dead code (the load is not
// predicated and always overwrites the register) and has been removed.
__device__ __forceinline__ void ldg32_nc_0(float &reg, const void *ptr) {
	asm volatile(
// Compare (major, minor) as a pair: testing MAJOR >= 11 && MINOR >= 4
// wrongly rejects 12.0 - 12.3 and any later x.0 - x.3 release.
#if (__CUDACC_VER_MAJOR__ > 11 ||                                             \
     (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 4)) &&            \
    __CUDA_ARCH__ >= 750
		"ld.global.nc.L2::128B.f32 %0, [%1];\n"
#else
		"ld.global.nc.f32 %0, [%1];\n"
#endif
		: "=f"(reg)
		: "l"(ptr));
}


// Predicated 32-bit global store.
//
//   reg   - value to store.
//   ptr   - global destination address.
//   guard - the st.global is executed only when this is true; when false
//           nothing is written.
__device__ __forceinline__
void stg32(const float &reg, void *ptr, bool guard) {
	asm volatile (
		"{.reg .pred p;\n"
		" setp.ne.b32 p, %2, 0;\n"
		" @p st.global.f32 [%0], %1;}\n"
		: : "l"(ptr), "f"(reg), "r"((int)guard)
		);
}

// 128-bit shared-memory load: reads four consecutive floats starting at the
// 32-bit shared address `addr` into reg0..reg3 with a single ld.shared.v4.
//
// NOTE(review): a v4.f32 access presumably requires `addr` to be 16-byte
// aligned - confirm against the PTX ISA; nothing here enforces it.
__device__ __forceinline__
void lds128(float &reg0, float &reg1,
	float &reg2, float &reg3,
	const uint32_t &addr) {
	asm volatile (
		"ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];\n"
		: "=f"(reg0), "=f"(reg1), "=f"(reg2), "=f"(reg3)
		: "r"(addr)
		);
}

// 32-bit shared-memory store: writes `reg` to the 32-bit shared address
// `addr` with st.shared.f32.
__device__ __forceinline__
void sts32(const float &reg, const uint32_t &addr) {
	asm volatile (
		"st.shared.f32 [%0], %1;\n"
		: : "r"(addr), "f"(reg)
		);
}

// 128-bit shared-memory store: writes reg0..reg3 as four consecutive floats
// starting at the 32-bit shared address `addr` with a single st.shared.v4.
//
// NOTE(review): a v4.f32 access presumably requires `addr` to be 16-byte
// aligned - confirm against the PTX ISA; nothing here enforces it.
__device__ __forceinline__
void sts128(const float &reg0, const float &reg1,
	const float &reg2, const float &reg3,
	const uint32_t &addr) {
	asm volatile (
		"st.shared.v4.f32 [%0], {%1, %2, %3, %4};\n"
		: : "r"(addr), "f"(reg0), "f"(reg1), "f"(reg2), "f"(reg3)
		);
}


// 64-bit shared-memory store: writes reg0 and reg1 as two consecutive floats
// starting at the 32-bit shared address `addr` with a single st.shared.v2.
//
// NOTE(review): a v2.f32 access presumably requires `addr` to be 8-byte
// aligned - confirm against the PTX ISA; nothing here enforces it.
__device__ __forceinline__
void sts64(const float &reg0, const float &reg1,
	const uint32_t &addr) {
	asm volatile (
		"st.shared.v2.f32 [%0], {%1, %2};\n"
		: : "r"(addr), "f"(reg0), "f"(reg1)
		);
}


// A 4x4 window copied out of a thread's 8x8 accumulator fragment.
// C_tile_wb takes one of these by value.
struct StgFrag {
	float data[4][4];

	// Copies rows [tile_y * 4, tile_y * 4 + 4) and columns
	// [tile_x * 4, tile_x * 4 + 4) of C_frag into data.
	__device__ __forceinline__
		StgFrag(const float(&C_frag)[8][8], int tile_x, int tile_y) {
		const int row_base = tile_y * 4;
		const int col_base = tile_x * 4;
#pragma unroll
		for (int r = 0; r < 4; ++r) {
#pragma unroll
			for (int c = 0; c < 4; ++c) {
				data[r][c] = C_frag[row_base + r][col_base + c];
			}
		}
	}
};

// Writes one 4x4 accumulator window per thread back to global memory by
// staging it through shared memory.
//
//   C_frag     - the thread's 4x4 values (passed by value).
//   C_stg_ptr  - global destination of this thread's first output element;
//                successive elements are n floats apart.
//   C_lds_ptr  - pointer the thread reads its staged values from, in steps of
//                32 floats.
//   C_sts_addr - 32-bit shared address the thread stages its rows to, in
//                steps of 8 float4.
//   m, n       - bounds used for the store guard (n is also the row pitch).
//   m_idx      - first row index covered by this call.
//   n_idx      - column index of this thread's output.
//
// Contains two block-wide barriers, so every thread of the block must call it.
__device__ __noinline__
void C_tile_wb(StgFrag C_frag,
	float *C_stg_ptr,
	const float *C_lds_ptr,
	uint32_t C_sts_addr,
	uint32_t m,
	uint32_t n,
	uint32_t m_idx,
	uint32_t n_idx) {
	// Barrier before the staging area is overwritten.
	__syncthreads();

	// Stage the four rows, one 128-bit store each.
#pragma unroll
	for (int row = 0; row < 4; ++row) {
		sts128(C_frag.data[row][0], C_frag.data[row][1],
			C_frag.data[row][2], C_frag.data[row][3],
			C_sts_addr + row * 8 * sizeof(float4));
	}

	// Barrier so every thread sees the fully staged data.
	__syncthreads();

	// Number of rows at or after m_idx that are still inside the matrix.
	uint32_t rows_left;
	if (m < m_idx) {
		rows_left = 0;
	}
	else {
		rows_left = m - m_idx;
	}

	// Sixteen guarded scalar stores, one per output row.
#pragma unroll
	for (int row = 0; row < 16; ++row) {
		const float *src = C_lds_ptr + row * 32;
		float *dst = C_stg_ptr + row * n;
		stg32(*src, dst, row < rows_left && n_idx < n);
	}
}


// Unconditional 128-bit global store: writes reg0..reg3 as four consecutive
// floats starting at `addr` with a single st.global.v4.
//
// NOTE(review): `addr` is declared `const float *` although it is the store
// destination; left unchanged here to keep the signature stable.
// NOTE(review): a v4.f32 access presumably requires `addr` to be 16-byte
// aligned - confirm against the PTX ISA; nothing here enforces it.
__device__ __forceinline__ void stg128(const float &reg0, const float &reg1,
	const float &reg2, const float &reg3,
	const float *addr) {
	asm volatile("st.global.v4.f32 [%0], {%1, %2, %3, %4};\n"
		:
	: "l"(addr), "f"(reg0), "f"(reg1), "f"(reg2), "f"(reg3));
}

// Sliced-K SGEMM: C (m x n) = A (m x k) * B (k x n), all row-major floats.
//
// Launch layout (as used by main): 256 threads per block, grid =
// (ceil(n / 64), ceil(m / 64)). Despite the name, each block produces a
// 64 x 64 tile of C whose top-left corner is (blockIdx.y * 64, blockIdx.x * 64).
//
// The 256 threads form four groups of 64 (threadIdx.x / 64, called "slice"
// in the comments below). Every slice accumulates the same 64 x 64 tile but
// over its own range of K that starts at slice * (k / 4). The four partial
// tiles are summed through shared memory in the epilogue.
//
// Within a slice the K range is consumed in tiles of depth 4, double-buffered
// in shared memory: while one buffer is used for FFMA, the next tile is
// fetched from global memory into registers and stored to the other buffer.
//
// NOTE(review): preconditions that are not checked in the kernel:
//  - B is loaded and C is stored as float4, and only the first element's
//    column is bounds-checked, so n presumably has to be a multiple of 4;
//  - each slice covers k / 4 columns (integer division), so the last k % 4
//    columns appear to be ignored unless k is a multiple of 4;
//  - the write-back branch only runs when m_idx + 32 <= m; the fallback for
//    smaller remainders is commented out, so such blocks write nothing.
__global__ __launch_bounds__(256, 2) void sgemm_128x128x8(uint32_t m,
	uint32_t n,
	uint32_t k,
	float *A,
	float *B,
	float *C) {

	// 20KB is reserved; the A and B double buffers laid out below occupy
	// 8704 + 8192 bytes = 16.5KB of it.
	__shared__ __align__(8 * 1024) char smem[20 * 1024];  // 16.5KB
	float *A_smem = reinterpret_cast<float *>(smem);
	float *B_smem = reinterpret_cast<float *>(smem + 8704);  // 8.5KB

	// A, B and C register fragment
	// A_frag / B_frag are double-buffered ([2]) so the next k-step can be
	// loaded while the current one is multiplied; C_frag is the thread's
	// 8 x 8 accumulator.
	float A_frag[2][8];
	float B_frag[2][8];
	float C_frag[8][8];
#pragma unroll
	for (int i = 0; i < 8; ++i) {
#pragma unroll
		for (int j = 0; j < 8; ++j) {
			C_frag[i][j] = 0;
		}
	}

	const uint32_t lane_id = threadIdx.x % 32;
	const uint32_t warp_id = threadIdx.x / 32;

	// 4x8 threads each warp for FFMA
	// mma_tid_x is in [0, 8), mma_tid_y is in [0, 4).
	const uint32_t mma_tid_x = (lane_id / 2) % 8;
	const uint32_t mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2);

	// A_tile & B_tile ldg pointer
	// from_a: element offset into A. Row = blockIdx.y * 64 + 4 * ((tid % 64) / 4),
	//         column = (tid % 64) % 4 + slice * (k / 4). The thread later reads
	//         four consecutive rows at this column.
	// from_b: element offset into B. Row = (tid % 64) / 16 + slice * (k / 4),
	//         column = blockIdx.x * 64 + 4 * ((tid % 64) % 16); read as one float4.
	// NOTE(review): these are plain int, so they can overflow for large m * k.
	int from_a = (blockIdx.y * 64 + (threadIdx.x % 64) / 4 * 4) * k + (threadIdx.x % 64) % 4 + (threadIdx.x / 64) * (k / 4);
	int from_b = ((threadIdx.x % 64) / 16 + (threadIdx.x / 64) * (k / 4)) * n + blockIdx.x * 64 + (threadIdx.x % 64) % 16 * 4;
	// A_tile & B_tile sts/lds pointer
	// using uint32_t pointer for faster double buffer switch

	// Shared-memory read addresses for the FFMA fragments. In A_smem a k-row
	// is 68 floats wide (64 used + 4 padding) and each slice owns 4 such rows
	// (68 * 4 floats); in B_smem a k-row is 64 floats and each slice owns
	// 64 * 4 floats.
	uint32_t A_lds_addr = smem_u32addr(
		A_smem + (warp_id % 2) * 32 + mma_tid_y * 4 + (threadIdx.x / 64) * 68 * 4);
	uint32_t B_lds_addr = smem_u32addr(
		B_smem + mma_tid_x * 4 + (threadIdx.x / 64) * 64 * 4);

	// Staging registers for the global -> shared copy of the next tile.
	float4 b_ldg_reg;
	float a_ldg_reg[4];


	// Shared-memory write addresses. A is stored transposed: the four rows a
	// thread loaded at one K column become four consecutive floats in k-row
	// (tid % 64) % 4 of its slice.
	uint32_t a_sts_addr = smem_u32addr(A_smem + ((threadIdx.x % 64) % 4) * 68 + ((threadIdx.x % 64) / 4) * 4 + (threadIdx.x / 64) * 68 * 4);
	uint32_t b_sts_addr = smem_u32addr(B_smem + ((threadIdx.x % 64) / 16) * 64 + ((threadIdx.x % 64) % 16) * 4 + (threadIdx.x / 64) * 64 * 4);
	// 1'st A&B tile loaded before the k_tile loop

	// k_tiles: ceil((k / 4) / 4) - 1, the number of iterations of the main
	//          loop; one further tile is consumed by the tail after the loop.
	// first_k_tile: exclusive upper K bound (global column index) of this
	//          slice's first tile, whose depth is k / 4 - k_tiles * 4, i.e.
	//          the remainder is handled first so later tiles are full.
	// NOTE(review): for k < 4 the subtraction wraps around (unsigned).
	uint32_t k_tiles = (k / 4 + 3) / 4 - 1;
	uint32_t first_k_tile = k / 4 - k_tiles * 4 + (k / 4)*(threadIdx.x / 64);
	// load 1'st tile to shared memory
	{
		// load first
		// load gmem to smem for ashare
		// Four rows of A at this thread's K column; out-of-range elements
		// (past the first tile in K, or past m in rows) are zero-filled.
#pragma unroll
		for (int i = 0; i < 4; ++i) {
			if ((threadIdx.x % 64) % 4 + (threadIdx.x / 64)*(k / 4) < first_k_tile  &&  blockIdx.y * 64 + (threadIdx.x % 64) / 4 * 4 + i < m) {
				ldg32_nc_0(a_ldg_reg[i], (const char *)(A + from_a) + i * k * sizeof(float));
			}
			else {
				a_ldg_reg[i] = 0;
			}
		}
		sts128(a_ldg_reg[0], a_ldg_reg[1], a_ldg_reg[2], a_ldg_reg[3], a_sts_addr);

		// load gmem to smem for bshare
		// One float4 of B; the first condition keeps the first loaded column
		// inside the row, the second keeps the K row inside the first tile.


		if (from_b < (1 + (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4))*n && (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4) < first_k_tile) {
			b_ldg_reg = FETCH_FLOAT4(B[from_b]);
		}
		else {
			b_ldg_reg = float4{ 0, 0, 0, 0 };
		}

		// Same location as b_sts_addr, written through a generic pointer.
		FETCH_FLOAT4(B_smem[((threadIdx.x % 64) / 16) * 64 + ((threadIdx.x % 64) % 16) * 4 + (threadIdx.x / 64) * 64 * 4]) = b_ldg_reg;
		__syncthreads();
		// add offset and flip flag
		// Advance the global offsets past the first tile and point the store
		// addresses at the second shared-memory buffer (one buffer holds all
		// four slices: 68 * 4 * 4 floats for A, 64 * 4 * 4 floats for B).
		from_a += k / 4 - k_tiles * 4;
		from_b += (k / 4 - k_tiles * 4) * n;
		a_sts_addr += 68 * 4 * 4 * sizeof(float);
		b_sts_addr += 64 * 4 * 4 * sizeof(float);
	}



	// load 1'st fragment
	// Each thread reads 8 A values (two float4, 16 floats apart) and 8 B
	// values (two float4, 32 floats apart) for k-step 0.
	lds128(A_frag[0][0], A_frag[0][1], A_frag[0][2], A_frag[0][3],
		A_lds_addr);
	lds128(A_frag[0][4], A_frag[0][5], A_frag[0][6], A_frag[0][7],
		A_lds_addr + 16 * sizeof(float));
	lds128(B_frag[0][0], B_frag[0][1], B_frag[0][2], B_frag[0][3],
		B_lds_addr);
	lds128(B_frag[0][4], B_frag[0][5], B_frag[0][6], B_frag[0][7],
		B_lds_addr + 32 * sizeof(float));



	// jump toggles each tile and selects the direction of the buffer swap.
	int jump = 0;
	// k_tiles loop
	for (; k_tiles > 0; --k_tiles) {
		jump ^= 1;
#pragma unroll
		for (int k_frag = 0; k_frag < 4; ++k_frag) {
			// store next A&B tile to shared memory
			// On the last k-step of the tile, publish the prefetched
			// registers to the idle buffer, synchronise, and swap roles.
			if (k_frag == 3) {
				sts128(a_ldg_reg[0], a_ldg_reg[1], a_ldg_reg[2], a_ldg_reg[3], a_sts_addr);
				sts128(b_ldg_reg.x, b_ldg_reg.y, b_ldg_reg.z, b_ldg_reg.w, b_sts_addr);
				__syncthreads();

				// switch double buffer
				if (jump == 1) {
					A_lds_addr += 68 * 4 * 4 * sizeof(float);
					B_lds_addr += 64 * 4 * 4 * sizeof(float);
					a_sts_addr -= 68 * 4 * 4 * sizeof(float);
					b_sts_addr -= 64 * 4 * 4 * sizeof(float);
				}
				else {
					A_lds_addr -= 68 * 4 * 4 * sizeof(float);
					B_lds_addr -= 64 * 4 * 4 * sizeof(float);
					a_sts_addr += 68 * 4 * 4 * sizeof(float);
					b_sts_addr += 64 * 4 * 4 * sizeof(float);
				}
				// ldg pointer for next tile
				from_a += 4;
				from_b += 4 * n;
			}

			// load next A&B fragment from shared memory to register
			// (k_frag + 1) % 4 wraps to k-row 0 of the freshly swapped-in
			// buffer when k_frag == 3.
			lds128(A_frag[(k_frag + 1) % 2][0],
				A_frag[(k_frag + 1) % 2][1],
				A_frag[(k_frag + 1) % 2][2],
				A_frag[(k_frag + 1) % 2][3],
				A_lds_addr + (k_frag + 1) % 4 * 68 * sizeof(float));
			lds128(A_frag[(k_frag + 1) % 2][4],
				A_frag[(k_frag + 1) % 2][5],
				A_frag[(k_frag + 1) % 2][6],
				A_frag[(k_frag + 1) % 2][7],
				A_lds_addr + ((k_frag + 1) % 4 * 68 + 16) * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][0],
				B_frag[(k_frag + 1) % 2][1],
				B_frag[(k_frag + 1) % 2][2],
				B_frag[(k_frag + 1) % 2][3],
				B_lds_addr + (k_frag + 1) % 4 * 64 * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][4],
				B_frag[(k_frag + 1) % 2][5],
				B_frag[(k_frag + 1) % 2][6],
				B_frag[(k_frag + 1) % 2][7],
				B_lds_addr + ((k_frag + 1) % 4 * 64 + 32) * sizeof(float));
			// load next A&B tile
			// Issued on the first k-step so the global loads overlap with
			// the FFMAs of the current tile. -k_tiles is an unsigned
			// negation; (-k_tiles * 4 + k / 4) is the K offset of the tile
			// being fetched relative to the start of the slice.
			// NOTE(review): the A guard compares against k, not against the
			// end of this slice - presumably fine because all tiles after
			// the first are full; confirm.
			if (k_frag == 0) {
				if (from_b < (1 + (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4))*n + (-k_tiles * 4 + k / 4)*n && (-k_tiles * 4 + k / 4) + (threadIdx.x % 64) / 16 + (threadIdx.x / 64)*(k / 4) < (threadIdx.x / 64 + 1)*(k / 4)) {

					b_ldg_reg = FETCH_FLOAT4(B[from_b]);
				}
				else {
					b_ldg_reg = float4{ 0, 0, 0, 0 };
				}

#pragma unroll
				for (int i = 0; i < 4; ++i) {
					if ((threadIdx.x % 64) % 4 + (threadIdx.x / 64)*(k / 4) + (-k_tiles * 4 + k / 4) < k && blockIdx.y * 64 + (threadIdx.x % 64) / 4 * 4 + i < m) {
						ldg32_nc_0(a_ldg_reg[i], (const char *)(A + from_a) + i * k * sizeof(float));
					}
					else {
						a_ldg_reg[i] = 0;
					}
				}
			}

			// FFMA loop
			// Rank-1 update of the 8 x 8 accumulator with the current k-step.
#pragma unroll
			for (int i = 0; i < 8; ++i) {
#pragma unroll
				for (int j = 0; j < 8; ++j) {
					C_frag[i][j] += A_frag[k_frag % 2][i] *
						B_frag[k_frag % 2][j];
				}
			}
		}
	}

	// FFMA for the last tile
	// Same as the loop body but without any further global loads or buffer
	// swap; the fragment prefetch is skipped on the final k-step.
#pragma unroll
	for (int k_frag = 0; k_frag < 4; ++k_frag) {
		if (k_frag < 3) {
			// load next A&B fragment from shared memory to register
			lds128(A_frag[(k_frag + 1) % 2][0],
				A_frag[(k_frag + 1) % 2][1],
				A_frag[(k_frag + 1) % 2][2],
				A_frag[(k_frag + 1) % 2][3],
				A_lds_addr + (k_frag + 1) % 4 * 68 * sizeof(float));
			lds128(A_frag[(k_frag + 1) % 2][4],
				A_frag[(k_frag + 1) % 2][5],
				A_frag[(k_frag + 1) % 2][6],
				A_frag[(k_frag + 1) % 2][7],
				A_lds_addr + ((k_frag + 1) % 4 * 68 + 16) * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][0],
				B_frag[(k_frag + 1) % 2][1],
				B_frag[(k_frag + 1) % 2][2],
				B_frag[(k_frag + 1) % 2][3],
				B_lds_addr + (k_frag + 1) % 4 * 64 * sizeof(float));
			lds128(B_frag[(k_frag + 1) % 2][4],
				B_frag[(k_frag + 1) % 2][5],
				B_frag[(k_frag + 1) % 2][6],
				B_frag[(k_frag + 1) % 2][7],
				B_lds_addr + ((k_frag + 1) % 4 * 64 + 32) * sizeof(float));
		}

		// FFMA loop
#pragma unroll
		for (int i = 0; i < 8; ++i) {
#pragma unroll
			for (int j = 0; j < 8; ++j) {
				C_frag[i][j] += A_frag[k_frag % 2][i] *
					B_frag[k_frag % 2][j];
			}
		}
	}


	// C_tile write back, reuse A&B tile shared memory buffer
	// C_sts_addr: where this thread stages its values; every warp gets its
	//             own 2048-byte region, so the two warps of a slice together
	//             fill 4KB (a 32 x 32 float tile).
	// C_lds_addr: where this thread reads the staged tile back for the
	//             reduction (row tid / 8 of 32 floats, float4 number tid % 8).
	// C_lds_ptr is only referenced by the commented-out fallback below.
	uint32_t C_sts_addr = smem_u32addr((float4 *)(smem + warp_id * 2048) +
		mma_tid_y * 4 * 8 + mma_tid_x);
	uint32_t C_lds_ptr = smem_u32addr(A_smem + (mma_tid_y * 4 * 8 + mma_tid_x) * 4 + (warp_id % 2) * 16 * 32);
	uint32_t C_lds_addr = smem_u32addr(A_smem + threadIdx.x / 8 * 32 + (threadIdx.x % 8) * 4);



	uint32_t m_idx = blockIdx.y * 64;

	// Both conditions depend only on blockIdx and m, so all threads of a
	// block take the same branch and the barriers below are reached by all.
	if (m_idx >= m) {
		return;
	}
	else if (m_idx + 32 <= m) {

		// The 8 x 8 accumulator is written out as four 4 x 4 quadrants
		// (i selects the row half, j the column half).
#pragma unroll
		for (int i = 0; i < 2; ++i) {


			// NOTE(review): unlike the surrounding loops this one has no
			// #pragma unroll although j indexes C_frag; if the compiler does
			// not unroll it on its own, C_frag may not stay in registers -
			// worth checking in the profiler.
			for (int j = 0; j < 2; ++j) {
				// Barrier before the staging area (which aliases the A/B
				// tiles) is overwritten.
				__syncthreads();

#pragma unroll
				for (int p = 0; p < 4; ++p) {
					sts128(C_frag[i * 4 + p][j * 4],
						C_frag[i * 4 + p][j * 4 + 1],
						C_frag[i * 4 + p][j * 4 + 2],
						C_frag[i * 4 + p][j * 4 + 3],
						C_sts_addr + p * 8 * sizeof(float4));
				}

				__syncthreads();


				// Read the same float4 from each of the four slices' 4KB
				// regions (32 * 32 floats apart). B_frag is reused as
				// scratch space here.
				lds128(B_frag[0][0], B_frag[0][1], B_frag[0][2], B_frag[0][3], C_lds_addr);
				lds128(B_frag[0][4], B_frag[0][5], B_frag[0][6], B_frag[0][7],
					C_lds_addr + (32 * 32) * sizeof(float));
				lds128(B_frag[1][0], B_frag[1][1], B_frag[1][2], B_frag[1][3],
					C_lds_addr + (32 * 32 * 2) * sizeof(float));
				lds128(B_frag[1][4], B_frag[1][5], B_frag[1][6], B_frag[1][7],
					C_lds_addr + (32 * 32 * 3) * sizeof(float));


				// Sliced-K reduction: sum the four partial results.
				B_frag[0][0] += (B_frag[0][4] + B_frag[1][0] + B_frag[1][4]);
				B_frag[0][1] += (B_frag[0][5] + B_frag[1][1] + B_frag[1][5]);
				B_frag[0][2] += (B_frag[0][6] + B_frag[1][2] + B_frag[1][6]);
				B_frag[0][3] += (B_frag[0][7] + B_frag[1][3] + B_frag[1][7]);

				// One float4 store per thread; only the first of the four
				// columns is compared against n.
				if (blockIdx.y * 64 + (threadIdx.x % 128) / 8 + i * 16 + (threadIdx.x / 128) * 32 < m&& blockIdx.x * 64 + (threadIdx.x % 8) * 4 + j * 32 < n) {
					stg128(B_frag[0][0], B_frag[0][1], B_frag[0][2], B_frag[0][3], C + (blockIdx.y * 64 + (threadIdx.x % 128) / 8 + i * 16 + (threadIdx.x / 128) * 32)*n + blockIdx.x * 64 + (threadIdx.x % 8) * 4 + j * 32);
				}
			}
		}
	}
	// NOTE(review): the fallback for m_idx < m < m_idx + 32 is disabled, so
	// such a block currently writes nothing to C. The commented-out code also
	// refers to C_stg_ptr and n_idx, which are not defined in this kernel.
	//	else {
	//#pragma unroll
	//		for (int i = 0; i < 2; ++i) {
	//#pragma unroll
	//			for (int j = 0; j < 2; ++j) {
	//				StgFrag stg_frag(C_frag, j, i);

	//				C_tile_wb(stg_frag,
	//					C_stg_ptr + i * 16 * n + j * 32,
	//					C_lds_ptr,
	//					C_sts_addr,
	//					m,
	//					n,
	//					m_idx + i * 16,
	//					n_idx + j * 32);
	//			}
	//		}
	//	}
}






// Allocates a row x col float matrix with new[] and fills it with a fixed
// pattern (despite the name, nothing here is random): in row-major order the
// first nine elements are 1, 2, ..., 9 and every later element is 0.5.
// The buffer is returned to the caller, who must release it with delete[].
float* random_matrix(int row, int col) {
	float* const out = new float[row * col];

	int r = 0;
	while (r < row) {
		int c = 0;
		while (c < col) {
			const int flat = r * col + c;
			const int ordinal = flat + 1;
			if (ordinal >= 10) {
				out[flat] = 0.5;
			}
			else {
				out[flat] = ordinal;
			}
			++c;
		}
		++r;
	}

	return out;
}

// float* random_matrix(int row, int col) {
// 	float* mat = new float[row * col];
//
//
// 	for (int i = 0; i < row; ++i) {
// 		for (int j = 0; j < col; ++j) {
// 			mat[i * col + j] = 0;
// 		}
// 	}
//
// 	return mat;
// }


// Prints a row-major row x col matrix to standard output for visual
// inspection: each element is followed by a space, each row ends with a
// newline, and a blank line plus a newline close the matrix.
void print_mat(float* mat, int row, int col) {
	for (int r = 0; r != row && r < row; ++r) {
		const int row_start = r * col;
		for (int c = 0; c < col; ++c) {
			std::cout << mat[row_start + c];
			std::cout << " ";
		}
		std::cout << std::endl;
	}
	std::cout << "\n";
	std::cout << std::endl;
}



// Host driver: builds A (m x k) and B (k x n), runs the sliced-K SGEMM kernel
// once under CUDA-event timing, copies C back and validates it against the
// CPU reference in check().
//
// Returns 0 when the result matches the reference, 1 when it does not. Any
// CUDA runtime failure is reported on stderr and terminates the process.
int main()
{
	const int m = 3072, k = 3072, n = 64;

	// The kernel loads B and stores C as float4, and gives each of its four
	// K-slices k / 4 columns, so these sizes must be multiples of 4.
	static_assert(n % 4 == 0, "n must be a multiple of 4 (float4 loads/stores)");
	static_assert(k % 4 == 0, "k must be a multiple of 4 (four equal K slices)");

	// Every CUDA runtime call is funnelled through this helper so failures
	// are reported where they happen instead of surfacing later as a wrong
	// result.
	auto cuda_check = [](cudaError_t err, const char *what) {
		if (err != cudaSuccess) {
			fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
			std::exit(EXIT_FAILURE);
		}
	};

	const size_t a_bytes = sizeof(float) * m * k;
	const size_t b_bytes = sizeof(float) * k * n;
	const size_t c_bytes = sizeof(float) * m * n;

	float* a = random_matrix(m, k);
	float* b = random_matrix(k, n);
	float* c = new float[m*n];

	float* dev_a = nullptr, *dev_b = nullptr, *dev_c = nullptr;

	cuda_check(cudaMalloc((void**)&dev_a, a_bytes), "cudaMalloc(dev_a)");
	cuda_check(cudaMalloc((void**)&dev_b, b_bytes), "cudaMalloc(dev_b)");
	cuda_check(cudaMalloc((void**)&dev_c, c_bytes), "cudaMalloc(dev_c)");

	// Zero C so that any tile the kernel fails to write shows up as a
	// deterministic mismatch rather than as uninitialised memory.
	cuda_check(cudaMemset(dev_c, 0, c_bytes), "cudaMemset(dev_c)");

	cuda_check(cudaMemcpy(dev_a, a, a_bytes, cudaMemcpyHostToDevice), "cudaMemcpy(a)");
	cuda_check(cudaMemcpy(dev_b, b, b_bytes, cudaMemcpyHostToDevice), "cudaMemcpy(b)");


	// One block per 64 x 64 tile of C.
	constexpr int BLOCK = 64;
	dim3 grid((n + BLOCK - 1) / BLOCK, (m + 64 - 1) / 64);
	int repeat = 1;


	cudaEvent_t start, stop;
	cuda_check(cudaEventCreate(&start), "cudaEventCreate(start)");
	cuda_check(cudaEventCreate(&stop), "cudaEventCreate(stop)");
	cuda_check(cudaEventRecord(start), "cudaEventRecord(start)");
	// Kept from the original; the result is deliberately not checked because
	// cudaErrorNotReady is an expected answer while the event is pending.
	cudaEventQuery(start);
	// Clear any such non-sticky status so the launch check below only
	// reports errors caused by the launch itself.
	(void)cudaGetLastError();

	for (int i = 0; i < repeat; i++) {
		sgemm_128x128x8<<<grid, 256>>>(m, n, k, dev_a, dev_b, dev_c);
		// Launch-configuration errors are only visible through this call.
		cuda_check(cudaGetLastError(), "sgemm_128x128x8 launch");
	}


	cuda_check(cudaEventRecord(stop), "cudaEventRecord(stop)");
	// Also surfaces asynchronous faults raised while the kernel was running.
	cuda_check(cudaEventSynchronize(stop), "cudaEventSynchronize(stop)");
	float elapsed_time;
	cuda_check(cudaEventElapsedTime(&elapsed_time, start, stop), "cudaEventElapsedTime");
	printf("Time = %g ms .\n", elapsed_time / repeat);
	cuda_check(cudaEventDestroy(start), "cudaEventDestroy(start)");
	cuda_check(cudaEventDestroy(stop), "cudaEventDestroy(stop)");


	cuda_check(cudaMemcpy(c, dev_c, c_bytes, cudaMemcpyDeviceToHost), "cudaMemcpy(c)");
	bool chk = check(a, b, c, m, n, k);
	printf("Matrix_C check: %s\n", chk ? "OK" : "Failed");
	//cout << 'a' << endl;
	//print_mat(a, m, k);
	//cout << 'b' << endl;
	//print_mat(b, k, n);
	//cout << 'c' << endl;
	//print_mat(c, m, n);

	// Release device and host buffers.
	cuda_check(cudaFree(dev_a), "cudaFree(dev_a)");
	cuda_check(cudaFree(dev_b), "cudaFree(dev_b)");
	cuda_check(cudaFree(dev_c), "cudaFree(dev_c)");
	delete[] a;
	delete[] b;
	delete[] c;

	// Let scripts detect a wrong result through the exit status.
	return chk ? 0 : 1;
}

Arsmart123 avatar Jun 27 '22 03:06 Arsmart123

you can take a look at this igemm example: https://github.com/NVIDIA/cutlass/blob/master/test/unit/gemm/threadblock/mma_pipelined_slicedk.cu#L42-L64 . We haven't added slicek support to sgemm, but the concepts are the same.

The shared memory load code is in https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/warp/mma_simt_tile_iterator.h#L1433-L1884. Template PartitionsK controls the slice number.

hwu36 avatar Jun 27 '22 15:06 hwu36

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] avatar Aug 04 '22 14:08 github-actions[bot]

@Arsmart123 were you able to resolve your issues?

mnicely avatar Sep 17 '22 12:09 mnicely