rocBLAS icon indicating copy to clipboard operation
rocBLAS copied to clipboard

[Bug]: rocblas_gemm_ex with m==1 fp16 inputs/outputs f32 compute slower than a quite naive gemv kernel on MI100

Open Epliz opened this issue 9 months ago • 18 comments

Describe the bug

As described in the title, rocblas_gemm_ex seems quite suboptimal on MI100 when m==1, the inputs/outputs are fp16, and the compute type is fp32. A quite naive kernel I implemented beats it.

This causes https://github.com/ROCm/pytorch/issues/1408 in pytorch. It makes fp16 LLM inference on Mistral 7b slower than it could easily be.

To Reproduce

Here is a C++ reproducer:

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include <iostream>
#include <chrono>
#include <functional>


// Number of output rows (elements of Y) that one thread block produces.
#define ROWS_PER_BLOCK 4
// Number of threads per block used by the host-side launch wrapper.
#define THREADS_PER_BLOCK 64

// Integer ceiling division of a by b (for non-negative a and positive b,
// provided a + b - 1 does not overflow).
#define DIV_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

// Lane masks with one bit set per lane of a 32-wide / 64-wide warp.
#define FULL_MASK32 0xffffffff
#define FULL_MASK64 0xffffffffffffffff

// Portable "shuffle down" wrapper: the CUDA device pass uses the masked
// *_sync intrinsic, the AMD HIP platform uses the mask-less __shfl_down
// (the mask argument is dropped there). Any other compiler is rejected.
#ifdef  __CUDA_ARCH__
#define __xx_shfl_down(mask, val, offset) __shfl_down_sync(mask, val, offset)
#elif defined(__HIP_PLATFORM_AMD__) // AMD
#define __xx_shfl_down(mask, val, offset) __shfl_down(val, offset)
#else
#error "Unsupported compiler"
#endif

// Adds up `val` over the lanes of the calling warp with a shuffle-down tree.
// Handles warp widths of 32 and 64; for any other width `val` is returned
// unchanged. Callers in this file only consume the value held by lane 0.
__device__ float warpReduce(float val) {
  if (warpSize == 32) {
    // 32-wide warp: halve the shuffle distance from 16 down to 1
    int distance = 16;
    while (distance > 0) {
      val = val + __xx_shfl_down(FULL_MASK32, val, distance);
      distance = distance / 2;
    }
  }
  if (warpSize == 64) {
    // 64-wide warp: same scheme, starting from a distance of 32
    int distance = 32;
    while (distance > 0) {
      val = val + __xx_shfl_down(FULL_MASK64, val, distance);
      distance = distance / 2;
    }
  }
  return val;
}

// Accumulates the 2-element dot product of `a` and `b` into `acc`:
// first the x components, then the y components.
static inline void __device__ dot2(float& acc, const float2& a, const float2& b) {
  acc = a.x * b.x + acc;
  acc = a.y * b.y + acc;
}

// Returns the address of element `index` of the array starting at `p`,
// computed as base pointer + 32-bit byte offset.
template <typename T>
static inline const T* __device__ addr(const T* p, unsigned index) {
  // Expressing the offset in bytes as a 32-bit value helps the AMDGPU compiler
  // pick the sgpr pair + single vgpr addressing mode.
  const unsigned offset_in_bytes = sizeof(T) * index;
  const uint8_t* base_bytes = reinterpret_cast<const uint8_t*>(p);
  return reinterpret_cast<const T*>(base_bytes + offset_in_bytes);
}

// Matrix-vector product Y = W * X (+ B) with fp16 storage and fp32 accumulation.
//
// Launch layout: 1D grid of DIV_ROUND_UP(N, ROWS_PER_BLOCK) blocks with exactly
// THREADS_PER_BLOCK threads each (the loop strides below are hard-wired to
// THREADS_PER_BLOCK). Block b produces rows [b * ROWS_PER_BLOCK, (b + 1) * ROWS_PER_BLOCK).
// Shared memory: ROWS_PER_BLOCK floats, statically allocated.
//
// Parameters:
//   W - row-major weight matrix, N rows of K halfs
//   B - optional bias of N halfs, may be nullptr
//   X - input vector of K halfs
//   Y - output vector of N halfs
//   N - number of rows of W / elements of Y
//   K - number of columns of W / elements of X
//
// Preconditions (not checked): W and X are at least 4-byte aligned so that the
// half2 loads of the fast path are aligned; K * sizeof(half) fits in 32 bits
// because addr() builds a 32-bit byte offset.
__global__ void muillm_gemv_kernel(
    const half* __restrict__ W, // weight matrix - size N x K
    const half* __restrict__ B, // optional bias - size N
    const half* __restrict__ X, // input = size K
    half* __restrict__ Y, // output - size N
    unsigned N,
    unsigned K
) {
  // The fast path below is unrolled by hand for exactly four rows; any other
  // value would either skip rows or index shared_accs out of bounds.
  static_assert(ROWS_PER_BLOCK == 4, "the unrolled fast path handles exactly 4 rows per block");

  const unsigned laneId = threadIdx.x % warpSize;

  // one fp32 accumulator per row handled by this block, shared by all warps
  __shared__ float shared_accs[ROWS_PER_BLOCK];

  // initialize the shared memory
  if (threadIdx.x < ROWS_PER_BLOCK) {
    shared_accs[threadIdx.x] = 0.f;
  }
  // with a single warp per block there is no other warp to wait for
  if (THREADS_PER_BLOCK > warpSize) {
    __syncthreads();
  }

  // Row indices and row offsets are computed in size_t: the product row * K
  // exceeds 32 bits as soon as the matrix has 2^32 or more elements.
  const size_t first_row = (size_t)blockIdx.x * ROWS_PER_BLOCK;

  // Rows start at W + row * K halfs. With an odd K every other row starts on a
  // 2-byte boundary, which would make the half2 loads misaligned, so the
  // vectorized path is only taken for even K.
  const bool rows_half2_aligned = (K % 2u) == 0u;

  if (rows_half2_aligned && first_row + 3 < N) {
    // fast path: four full rows, two columns per thread and iteration
    const half* W0 = W + (first_row + 0) * K;
    const half* W1 = W + (first_row + 1) * K;
    const half* W2 = W + (first_row + 2) * K;
    const half* W3 = W + (first_row + 3) * K;

    float acc0 = 0.f;
    float acc1 = 0.f;
    float acc2 = 0.f;
    float acc3 = 0.f;

    // K is even here, so the pairs (k, k + 1) cover every column and no
    // scalar remainder is needed.
    for (unsigned k = threadIdx.x * 2; k + 1 < K; k += (THREADS_PER_BLOCK * 2)) {
      float2 x = __half22float2(*((const half2*)addr(X, k)));
      float2 w0 = __half22float2(*((const half2*)addr(W0, k)));
      float2 w1 = __half22float2(*((const half2*)addr(W1, k)));
      float2 w2 = __half22float2(*((const half2*)addr(W2, k)));
      float2 w3 = __half22float2(*((const half2*)addr(W3, k)));

      dot2(acc0, w0, x);
      dot2(acc1, w1, x);
      dot2(acc2, w2, x);
      dot2(acc3, w3, x);
    }

    // warp reduce
    acc0 = warpReduce(acc0);
    acc1 = warpReduce(acc1);
    acc2 = warpReduce(acc2);
    acc3 = warpReduce(acc3);

    // reduce across warps: one atomic per warp and row
    if (laneId == 0) {
      atomicAdd(&shared_accs[0], acc0);
      atomicAdd(&shared_accs[1], acc1);
      atomicAdd(&shared_accs[2], acc2);
      atomicAdd(&shared_accs[3], acc3);
    }
  } else {
    // generic path: partial last block and/or odd K, one row at a time with
    // scalar loads
    for (unsigned i = 0; i < ROWS_PER_BLOCK; i++) {
      const size_t current_row = first_row + i;

      // uniform across the block, so every thread leaves the loop together
      if (current_row >= N)
        break;

      const half* W_ = W + current_row * K;

      // do the dot product
      float acc = 0.f;
      for (size_t k = threadIdx.x; k < K; k += THREADS_PER_BLOCK) {
        float w = __half2float(W_[k]);
        acc += w * __half2float(X[k]);
      }

      // warp reduce
      acc = warpReduce(acc);

      // reduce across warps
      if (laneId == 0) {
        atomicAdd(&shared_accs[i], acc);
      }
    }
  }

  // make the atomics of all warps visible before reading the totals
  if (THREADS_PER_BLOCK > warpSize) {
    __syncthreads();
  }

  // write out the results: thread t of the block writes row first_row + t
  if (threadIdx.x >= ROWS_PER_BLOCK)
    return;

  const size_t out_row = first_row + threadIdx.x;

  if (out_row < N) {
    float acc = shared_accs[threadIdx.x]; // read the fully reduced value
    if (B != nullptr) { // add the bias first if there is one
      acc += __half2float(B[out_row]);
    }

    // write the output value
    Y[out_row] = __float2half(acc);
  }
}
void muillm_linear_forward_cuda(
    const half* __restrict__ W, // size N x K
    const half* __restrict__ B, // size N
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {

  const int threads_per_blocks = THREADS_PER_BLOCK;
  const int num_blocks = DIV_ROUND_UP(N, ROWS_PER_BLOCK);

  muillm_gemv_kernel<<<num_blocks, threads_per_blocks, 0, 0>>>(
    W,
    B,
    X,
    Y,
    N,
    K
  );
}

// Computes Y = W * X through rocblas_gemm_ex with fp16 inputs/outputs and fp32
// compute, expressed as a GEMM with m == 1.
//
// Parameters:
//   handle - rocBLAS handle to issue the call on
//   W      - device pointer, row-major weights, N x K halfs
//   X      - device pointer to K input halfs
//   Y      - device pointer receiving N output halfs
//   N      - number of output elements / rows of W
//   K      - number of input elements / columns of W
//
// A non-success status from rocBLAS is reported on stderr; without this check a
// rejected call would return immediately and be measured as a very fast GEMM.
static inline void rocblas_sgemv(rocblas_handle handle,
    const half* __restrict__ W, // size N x K
    const half* __restrict__ X, // size K
    half* __restrict__ Y, // size N
    unsigned N,
    unsigned K) {
  float alpha = 1.0f;
  float beta = 0.f;

  // adapted for row major from https://stackoverflow.com/questions/56043539/cublassgemm-row-major-multiplication
  // In column-major terms: A = X seen as 1 x K (lda = 1), B = W seen as K x N
  // (ldb = K, i.e. the row-major N x K matrix), D = Y seen as 1 x N (ldd = 1).
  const rocblas_status status = rocblas_gemm_ex(handle,
                  rocblas_operation_none /*transA*/,
                  rocblas_operation_none /*transB*/,
                  1 /*m*/,
                  N /*n*/,
                  K /*k*/,
                  &alpha,
                  X /*a*/,
                  rocblas_datatype_f16_r /*a_type*/,
                  1 /*lda*/,
                  W /*b*/,
                  rocblas_datatype_f16_r /*b_type*/,
                  K /*ldb*/,
                  &beta,
                  nullptr /*c*/,
                  rocblas_datatype_f16_r /*c_type*/,
                  1 /*ldc*/,
                  Y /*d*/,
                  rocblas_datatype_f16_r /*d_type*/,
                  1 /*ldd*/,
                  rocblas_datatype_f32_r /*compute_type*/,
                  rocblas_gemm_algo_standard /*algo*/,
                  0 /*solution_index*/,
                  0 /*flags*/);

  if (status != rocblas_status_success) {
    std::cerr << "rocblas_gemm_ex failed with status "
              << static_cast<int>(status) << std::endl;
  }
}

size_t timeus_func(size_t count, std::function<void(int)> f) {
  std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
  f(count);
  hipDeviceSynchronize();

  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();

  return std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / count;
}

int main(int argc, char** argv) {
  int in_features=4096, out_features=14336;
  int tot_features = in_features * out_features;

  // allocate matrices and vectors
  half* x_small = nullptr;
  half* x_big = nullptr;
  half* w_up = nullptr;
  half* w_down = nullptr;

  std::cout<<"Allocating memory..."<<std::endl;
  if (hipMalloc(&x_small, sizeof(half) * in_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&x_big, sizeof(half) * out_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&w_up, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  if (hipMalloc(&w_down, sizeof(half) * tot_features) != hipSuccess) {
    return -1;
  }

  // set memory
  std::cout<<"Setting memory..."<<std::endl;
  if (hipMemsetD16(x_small, 0, in_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(x_big, 0, out_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_up, 0, tot_features) != hipSuccess) {
    return -1;
  }
  if (hipMemsetD16(w_down, 0, tot_features) != hipSuccess) {
    return -1;
  }

  //
  std::cout<<"Running..."<<std::endl;

  int count = 10000;

  {

    auto mui_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        muillm_linear_forward_cuda(w_up, nullptr, x_small, x_big, out_features, in_features);
        muillm_linear_forward_cuda(w_down, nullptr, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      mui_prof
    );

    // measurement
    size_t mui_time = timeus_func(
      count,
      mui_prof
    );

    std::cout<<"mui: "<<mui_time<<"us/loop"<<std::endl;
  }

  {// rocblas
    rocblas_initialize();
    rocblas_handle handle;
    if(rocblas_create_handle(&handle) != rocblas_status_success) return -3;

    auto rocblas_prof = [=] (int count) {
      for (int i = 0; i < count; i++) {
        rocblas_sgemv(handle, w_up, x_small, x_big, out_features, in_features);
        rocblas_sgemv(handle, w_down, x_big, x_small, in_features, out_features);
      }
    };

    // warmup
    size_t discarded = timeus_func(
      10,
      rocblas_prof
    );

    // measurement
    size_t rocblas_time = timeus_func(
      count,
      rocblas_prof
    );

    std::cout<<"rocblas: "<<rocblas_time<<"us/loop"<<std::endl;
  }

  std::cout<<"DONE"<<std::endl;
  return 0;
}

Expected behavior

It should be at least as fast as my naive kernel. But running the above, I get:

Allocating memory...
Setting memory...
Running...
mui: 227us/loop
rocblas: 386us/loop
hipblas: 386us/loop
DONE

Environment

Hardware description
CPU AMD Ryzen 7 5800X3D 8-Core Processor
GPU AMD Instinct MI100
Software version
rocm-core v6.0.2.60002-115~22.04
rocblas v4.0.0.60002-115~22.04

environment.txt

Additional context

Add any other context about the problem here.

EDIT: replaced the originally included kernel with a better one. EDIT2: replaced it again with an even better kernel.

Epliz avatar May 05 '24 11:05 Epliz