rocBLAS
[Bug]: rocblas_gemm_ex with m==1 fp16 inputs/outputs f32 compute slower than a quite naive gemv kernel on MI100
Describe the bug
As described in the title, rocblas_gemm_ex seems quite suboptimal on MI100 when m==1, the inputs/outputs are fp16, and the compute type is fp32: a fairly naive kernel I implemented beats it.
This causes https://github.com/ROCm/pytorch/issues/1408 in PyTorch and makes LLM inference on Mistral 7B fp16 slower than it could easily be.
To Reproduce
Here is a C++ reproducer:
```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>

#include <cstdint>
#include <iostream>
#include <chrono>
#include <functional>

#define ROWS_PER_BLOCK 4
#define THREADS_PER_BLOCK 64

#define DIV_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define FULL_MASK32 0xffffffff
#define FULL_MASK64 0xffffffffffffffff

#ifdef __CUDA_ARCH__
#define __xx_shfl_down(mask, val, offset) __shfl_down_sync(mask, val, offset)
#elif defined(__HIP_PLATFORM_AMD__) // AMD
#define __xx_shfl_down(mask, val, offset) __shfl_down(val, offset)
#else
#error "Unsupported compiler"
#endif

__device__ float warpReduce(float val) {
  if (warpSize == 32) {
    for (int offset = 16; offset > 0; offset /= 2)
      val += __xx_shfl_down(FULL_MASK32, val, offset);
  }
  if (warpSize == 64) {
    for (int offset = 32; offset > 0; offset /= 2)
      val += __xx_shfl_down(FULL_MASK64, val, offset);
  }
  return val;
}

static inline void __device__ dot2(float& acc, const float2& a, const float2& b) {
  acc += a.x * b.x;
  acc += a.y * b.y;
}

template <typename T>
static inline const T* __device__ addr(const T* p, unsigned index) {
  // helps the AMDGPU compiler understand it can use the sgpr pair + single vgpr addressing mode
  unsigned byte_offset = sizeof(T) * index;
  const uint8_t* p8 = (const uint8_t*)p;
  return (const T*) (p8 + byte_offset);
}
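
// Kernel overview: each block computes ROWS_PER_BLOCK (= 4) consecutive rows of Y.
// Threads stride over K loading half2 pairs, accumulate in fp32, reduce within each warp
// via shuffles, combine across warps through shared-memory atomicAdd, and the first
// ROWS_PER_BLOCK threads write the fp16 results (adding the optional bias).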
__global__ void muillm_gemv_kernel(
    const half* __restrict__ W, // weight matrix - size N x K
    const half* __restrict__ B, // optional bias - size N
    const half* __restrict__ X, // input - size K
    half* __restrict__ Y, // output - size N
    unsigned N,
    unsigned K
) {
  int warpCounts = THREADS_PER_BLOCK / warpSize;
  int warpId = threadIdx.x / warpSize;
  int laneId = threadIdx.x % warpSize;

  // can process ROWS_PER_BLOCK rows
  // shared state to do the reductions
  __shared__ float shared_accs[ROWS_PER_BLOCK];

  // initialize the shared memory
  if (threadIdx.x < ROWS_PER_BLOCK) {
    shared_accs[threadIdx.x] = 0.f;
  }
  if (THREADS_PER_BLOCK > warpSize) {
    __syncthreads();
  }

  {
    int current_row = blockIdx.x * ROWS_PER_BLOCK + 0;
    if (current_row + 3 < N) {
      // compute the t-th element of Y by doing the dot product with the
      // t-th row of W
      const half* W0 = &W[(current_row + 0) * K];
      const half* W1 = &W[(current_row + 1) * K];
      const half* W2 = &W[(current_row + 2) * K];
      const half* W3 = &W[(current_row + 3) * K];

      float acc0 = 0.f;
      float acc1 = 0.f;
      float acc2 = 0.f;
      float acc3 = 0.f;

      // do the dot product
      {
        unsigned k; // should be 2 * tidx ?
        //*
        for (k = threadIdx.x * 2; k + 1 < K; k += (THREADS_PER_BLOCK * 2)) {
          // vectorized
          float2 x = __half22float2(*((const half2*)addr(X, k)));
          float2 w0 = __half22float2(*((const half2*)addr(W0, k)));
          float2 w1 = __half22float2(*((const half2*)addr(W1, k)));
          float2 w2 = __half22float2(*((const half2*)addr(W2, k)));
          float2 w3 = __half22float2(*((const half2*)addr(W3, k)));

          dot2(acc0, w0, x);
          dot2(acc1, w1, x);
          dot2(acc2, w2, x);
          dot2(acc3, w3, x);
        }
        if (k < K) {
          // remainder
          float x = __half2float(*addr(X,k));
          float w0 = __half2float(*addr(W0,k));
          float w1 = __half2float(*addr(W1,k));
          float w2 = __half2float(*addr(W2,k));
          float w3 = __half2float(*addr(W3,k));
          acc0 += w0 * x;
          acc1 += w1 * x;
          acc2 += w2 * x;
          acc3 += w3 * x;
        }
      }

      // warp reduce
      acc0 = warpReduce(acc0);
      acc1 = warpReduce(acc1);
      acc2 = warpReduce(acc2);
      acc3 = warpReduce(acc3);

      // reduce across warps
      if (laneId == 0) {
        atomicAdd(&shared_accs[0], acc0);
        atomicAdd(&shared_accs[1], acc1);
        atomicAdd(&shared_accs[2], acc2);
        atomicAdd(&shared_accs[3], acc3);
      }
    } else {
      for (int i = 0; i < ROWS_PER_BLOCK; i++) {
        // compute the t-th element of Y by doing the dot product with the
        // t-th row of W
        int current_row = blockIdx.x * ROWS_PER_BLOCK + i;
        if (current_row >= N)
          break;

        const half* W_ = &W[current_row * K];

        // do the dot product
        float acc = 0.f;
        for (int k = threadIdx.x; k < K; k += THREADS_PER_BLOCK) {
          float w = __half2float(W_[k]);
          acc += w * __half2float(X[k]);
        }

        // warp reduce
        acc = warpReduce(acc);

        // reduce across warps
        if (laneId == 0) {
          atomicAdd(&shared_accs[i], acc);
        }
      }
    }
  }

  if (THREADS_PER_BLOCK > warpSize) {
    __syncthreads();
  }

  // write out the results
  {
    if (threadIdx.x >= ROWS_PER_BLOCK)
      return;

    int current_row = blockIdx.x * ROWS_PER_BLOCK + threadIdx.x;
    if (current_row < N) {
      float acc = shared_accs[threadIdx.x]; // read the fully reduced value
      if (B != nullptr) { // add the bias first if there is one
        acc += __half2float(B[current_row]);
      }
      // write the output value
      Y[current_row] = __float2half(acc);
    }
  }
}
void muillm_linear_forward_cuda(
    const half* __restrict__ W, // size N x K
    const half* __restrict__ B, // size N
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {
  const int threads_per_blocks = THREADS_PER_BLOCK;
  const int num_blocks = DIV_ROUND_UP(N, ROWS_PER_BLOCK);

  muillm_gemv_kernel<<<num_blocks, threads_per_blocks, 0, 0>>>(
    W,
    B,
    X,
    Y,
    N,
    K
  );
}
static inline void rocblas_sgemv(rocblas_handle handle,
    const half* __restrict__ W, // size N x K
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {
  float alpha = 1.0f;
  float beta = 0.f;

  // adapted for row major from https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication
  rocblas_gemm_ex(handle,
                  rocblas_operation_none /*transA*/,
                  rocblas_operation_none /*transB*/,
                  1 /*m*/,
                  N /*n*/,
                  K /*k*/,
                  &alpha,
                  X /*a*/,
                  rocblas_datatype_f16_r /*a_type*/,
                  1 /*lda*/,
                  W /*b*/,
                  rocblas_datatype_f16_r /*b_type*/,
                  K /*ldb*/,
                  &beta,
                  nullptr /*c*/,
                  rocblas_datatype_f16_r /*c_type*/,
                  1 /*ldc*/,
                  Y /*d*/,
                  rocblas_datatype_f16_r /*d_type*/,
                  1 /*ldd*/,
                  rocblas_datatype_f32_r /*compute_type*/,
                  rocblas_gemm_algo_standard /*algo*/,
                  0 /*solution_index*/,
                  0 /*flags*/);
}
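
// Note: in rocBLAS's column-major convention, the call above computes
// D (1 x N) = alpha * A (1 x K) * B (K x N) with A = X (lda = 1) and B read from W with ldb = K,
// i.e. B(k, n) = W[n * K + k], so D(0, n) = sum_k W[n * K + k] * X[k] = Y[n],
// which matches the row-major Y = W * X computed by the kernel above.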
size_t timeus_func(size_t count, std::function<void(int)> f) {
  std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
  f(count);
  hipDeviceSynchronize();
  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
  return std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / count;
}
int main(int argc, char** argv) {
  int in_features = 4096, out_features = 14336;
  int tot_features = in_features * out_features;

  // allocate matrices and vectors
  half* x_small = nullptr;
  half* x_big = nullptr;
  half* w_up = nullptr;
  half* w_down = nullptr;

  std::cout << "Allocating memory..." << std::endl;

  if (hipMalloc(&x_small, sizeof(half) * in_features) != hipSuccess) {
    return -1;
  }
  if (hipMalloc(&x_big, sizeof(half) * out_features) != hipSuccess) {
    return -1;
  }
  if (hipMalloc(&w_up, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }
  if (hipMalloc(&w_down, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  // set memory
  std::cout << "Setting memory..." << std::endl;

  if (hipMemsetD16(x_small, 0, in_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(x_big, 0, out_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_up, 0, tot_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_down, 0, tot_features) != hipSuccess) {
    return -1;
  }

  std::cout << "Running..." << std::endl;

  int count = 10000;

  {
    auto mui_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        muillm_linear_forward_cuda(w_up, nullptr, x_small, x_big, out_features, in_features);
        muillm_linear_forward_cuda(w_down, nullptr, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      mui_prof
    );

    // measurement
    size_t mui_time = timeus_func(
      count,
      mui_prof
    );

    std::cout << "mui: " << mui_time << "us/loop" << std::endl;
  }

  { // rocblas
    rocblas_initialize();

    rocblas_handle handle;
    if (rocblas_create_handle(&handle) != rocblas_status_success) return -3;

    auto rocblas_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        rocblas_sgemv(handle, w_up, x_small, x_big, out_features, in_features);
        rocblas_sgemv(handle, w_down, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      rocblas_prof
    );

    // measurement
    size_t rocblas_time = timeus_func(
      count,
      rocblas_prof
    );

    std::cout << "rocblas: " << rocblas_time << "us/loop" << std::endl;
  }

  std::cout << "DONE" << std::endl;
  return 0;
}
```
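For reference, I build the reproducer with something along the lines of `hipcc repro.cpp -o repro -lrocblas` (file name and flags are just an example; adjust to your setup).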
Expected behavior
It should be at least as fast as my naive kernel. But running the above, I get:
```
Allocating memory...
Setting memory...
Running...
mui: 227us/loop
rocblas: 386us/loop
hipblas: 386us/loop
DONE
```
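
As a rough back-of-the-envelope (assuming the weight-matrix reads dominate memory traffic), each loop iteration streams both weight matrices once:

$$2 \times 4096 \times 14336 \times 2\,\text{B} \approx 235\,\text{MB},\qquad \frac{235\,\text{MB}}{227\,\mu\text{s}} \approx 1.03\,\text{TB/s},\qquad \frac{235\,\text{MB}}{386\,\mu\text{s}} \approx 0.61\,\text{TB/s},$$

so the naive kernel runs close to the MI100's ~1.2 TB/s peak HBM2 bandwidth while the rocblas path stays well below it.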
Environment
| Hardware | description |
|---|---|
| CPU | AMD Ryzen 7 5800X3D 8-Core Processor |
| GPU | AMD Instinct MI100 |

| Software | version |
|---|---|
| rocm-core | v6.0.2.60002-115~22.04 |
| rocblas | v4.0.0.60002-115~22.04 |
Additional context
EDIT: put a better kernel than the originally included one. EDIT2: put an even better kernel.