stdBLAS

GEMM Strides/Extents to CBLAS mapping

Open · crtrott opened this issue 3 years ago · 1 comment

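Does this seem right? For each matrix, the table lists its extents (extent(0), extent(1)) followed by its strides (stride(0), stride(1)). "left" and "right" denote the column-major (stride(0) == 1) and row-major (stride(1) == 1) layouts those strides describe.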

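| C | A | B | de facto layouts (C, A, B) | BLAS call |
|---|---|---|---|---|
| (M,N), (1,M) | (M,K), (1,M) | (K,N), (1,K) | (left, left, left) | gemm('N','N', M, N, K, 1., A.data(), M, B.data(), K, 1., C.data(), M) |
| (M,N), (1,M) | (M,K), (K,1) | (K,N), (1,K) | (left, right, left) | gemm('T','N', M, N, K, 1., A.data(), K, B.data(), K, 1., C.data(), M) |
| (M,N), (1,M) | (M,K), (1,M) | (K,N), (N,1) | (left, left, right) | gemm('N','T', M, N, K, 1., A.data(), M, B.data(), N, 1., C.data(), M) |
| (M,N), (1,M) | (M,K), (K,1) | (K,N), (N,1) | (left, right, right) | gemm('T','T', M, N, K, 1., A.data(), K, B.data(), N, 1., C.data(), M) |
| (M,N), (N,1) | (M,K), (1,M) | (K,N), (1,K) | (right, left, left) | gemm('T','T', N, M, K, 1., B.data(), K, A.data(), M, 1., C.data(), N) |
| (M,N), (N,1) | (M,K), (K,1) | (K,N), (1,K) | (right, right, left) | gemm('T','N', N, M, K, 1., B.data(), K, A.data(), K, 1., C.data(), N) |
| (M,N), (N,1) | (M,K), (1,M) | (K,N), (N,1) | (right, left, right) | gemm('N','T', N, M, K, 1., B.data(), N, A.data(), M, 1., C.data(), N) |
| (M,N), (N,1) | (M,K), (K,1) | (K,N), (N,1) | (right, right, right) | gemm('N','N', N, M, K, 1., B.data(), N, A.data(), K, 1., C.data(), N) |

BLAS treats every operand as column-major, so each leading dimension is the extent of the matrix's unit-stride dimension (more generally, its non-unit stride):

- A: M (left) or K (right).
- B: K (left) or N (right).
- C: M (left) or N (right).

A row-major (right) operand is its column-major transpose, so it gets 'T' when C is left. A row-major C is handled as C^T = B^T A^T, which swaps the roles of A and B and of M and N. These leading dimensions agree with the rule `stride(0)==1 ? extent(0) : extent(1)` used in the test code below.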

— crtrott, Feb 4, 2022

Here is the test code:

#include <algorithm>
#include <stdexcept>

#include<Kokkos_Core.hpp>
#include<Kokkos_Random.hpp>

// Fortran BLAS double-precision GEMM: C := alpha*op(A)*op(B) + beta*C, with
// all matrices column-major and every argument passed by pointer (Fortran
// calling convention, trailing underscore name mangling).
// NOTE(review): Fortran CHARACTER arguments are usually accompanied by hidden
// length arguments that this prototype omits. That is common practice but
// compiler/ABI dependent — confirm for the BLAS library being linked.
extern "C" void dgemm_(const char* transa, const char* transb,
                       int* m, int* n, int* k,
                       double* alpha,
                       double* a, int* lda,
                       double* b, int* ldb,
                       double* beta,
                       double* c, int* ldc);

// How a rank-2 strided view is presented to the column-major Fortran BLAS.
struct GemmOperand {
  bool transposed;  // true: view is row-major, i.e. stored as its transpose in column-major terms
  int ld;           // leading dimension to pass to dgemm_
};

// Classifies a rank-2 view as column-major (stride(0) == 1) or row-major
// (stride(1) == 1) and returns the matching BLAS leading dimension.
// Column-major is preferred when both strides are 1 (an extent <= 1), so a
// view is classified exactly once.
// Throws std::invalid_argument if neither stride is 1: dgemm_ cannot address
// such a view.
template<class VT>
GemmOperand gemm_operand(const VT& X) {
  const int rows = static_cast<int>(X.extent(0));
  const int cols = static_cast<int>(X.extent(1));
  // The leading dimension is the non-unit *stride*, not an extent, which keeps
  // padded allocations and subviews correct. BLAS also requires
  // ld >= max(1, rows of the stored matrix). For a non-overlapping layout the
  // stride can only be smaller than that when the other extent is <= 1. The
  // stride is then never used for addressing, so clamping is harmless.
  if(X.stride(0)==1)
    return {false, std::max({1, rows, static_cast<int>(X.stride(1))})};
  if(X.stride(1)==1)
    return {true, std::max({1, cols, static_cast<int>(X.stride(0))})};
  throw std::invalid_argument("gemm: every matrix needs a unit stride in one dimension");
}

// Computes C = A * B (alpha = 1, beta = 0, so C is overwritten) for rank-2
// views of double by making exactly one call to the Fortran BLAS dgemm_.
//   C: M x N, A: M x K, B: K x N. Each may be column-major or row-major,
//   including padded layouts and subviews; leading dimensions come from strides.
// The raw data pointers are handed to host BLAS, so the views presumably must
// live in host-accessible memory — TODO confirm for non-host backends.
// Throws std::invalid_argument if the extents do not conform or a view has no
// unit stride (previously such inputs were silently ignored).
template<class AT, class BT, class CT>
void gemm(CT C, AT A, BT B) {
  int M = static_cast<int>(C.extent(0));
  int N = static_cast<int>(C.extent(1));
  int K = static_cast<int>(A.extent(1));
  if(static_cast<int>(A.extent(0)) != M ||
     static_cast<int>(B.extent(0)) != K ||
     static_cast<int>(B.extent(1)) != N)
    throw std::invalid_argument("gemm: extents of C, A and B do not conform");

  const GemmOperand a = gemm_operand(A);
  const GemmOperand b = gemm_operand(B);
  const GemmOperand c = gemm_operand(C);
  // dgemm_ takes non-const pointers, so these must be mutable lvalues.
  int LDA = a.ld, LDB = b.ld, LDC = c.ld;
  double alpha = 1., beta = 0.;
  double* A_data = A.data();
  double* B_data = B.data();
  double* C_data = C.data();

  if(!c.transposed) {
    // Column-major C: C = op(A) * op(B). A row-major operand is already the
    // transpose in column-major terms, so it needs 'T'.
    dgemm_(a.transposed ? "T" : "N", b.transposed ? "T" : "N",
           &M,&N,&K,&alpha,A_data,&LDA,B_data,&LDB,&beta,C_data,&LDC);
  } else {
    // Row-major C is column-major C^T (N x M), and C^T = B^T * A^T. So swap
    // the operands and M/N. Now a column-major operand needs 'T', while a
    // row-major one is already transposed in memory.
    dgemm_(b.transposed ? "N" : "T", a.transposed ? "N" : "T",
           &N,&M,&K,&alpha,B_data,&LDB,A_data,&LDA,&beta,C_data,&LDC);
  }
}

// Fills A (M x K, layout LA) and B (K x N, layout LB) with random values in
// [0, 1). It computes a reference C2 = A * B with a naive kernel and C with
// gemm(), then prints C(3,3)/C2(3,3) and the number of entries that disagree.
template<class LC, class LA, class LB>
void testgemm(int M, int N, int K) {
  Kokkos::View<double**,LA> A("A",M,K);
  Kokkos::View<double**,LB> B("B",K,N);
  Kokkos::View<double**,LC> C("C",M,N),C2("C2",M,N);

  Kokkos::Random_XorShift64_Pool<> g(1321);
  Kokkos::fill_random(A,g,1.0);
  Kokkos::fill_random(B,g,1.0);

  Kokkos::parallel_for("CreateReference",
    Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0,0}, {C.extent(0), C.extent(1)}),
    KOKKOS_LAMBDA(int i, int j) {
      double sum = 0;
      for(int k=0; k<K; k++) {
        sum += A(i,k)*B(k,j);
      }
      C2(i,j) = sum;
    }
  );
  // Kokkos kernels (fill_random, parallel_for) may complete asynchronously.
  // Make sure A and B are fully written before host BLAS reads them.
  Kokkos::fence();
  gemm(C,A,B);

  int total_errors = 0;
  Kokkos::parallel_reduce("CheckEquivalence",
    Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0,0}, {C.extent(0), C.extent(1)}),
    KOKKOS_LAMBDA(int i, int j, int& errors) {
      const double ref  = C2(i,j);
      const double diff = C(i,j) - ref;
      // Two-sided check. The old one-sided `diff > tol` missed C < C2, e.g. a
      // gemm that never wrote C, which leaves zeros against positive
      // references. The tolerance scales with |ref| because BLAS may sum the K
      // products in a different order. The negated form also counts NaN.
      const double tol  = 1e-13 * (1.0 + (ref < 0 ? -ref : ref));
      if(!(diff <= tol && -diff <= tol)) errors++;
      if(i==3 && j==3) printf("%lf %lf\n",C(i,j),C2(i,j));
  },total_errors);
  printf("Errors: %i\n",total_errors);
}

// Checks gemm() against a naive reference for all eight LayoutLeft /
// LayoutRight combinations of (C, A, B), with C: M x N, A: M x K, B: K x N.
int main(int argc, char* argv[]) {
  Kokkos::initialize(argc,argv);
  {
  // Inner scope: all Views created by testgemm are released before
  // Kokkos::finalize(). Distinct, non-square sizes make a swapped dimension or
  // leading dimension likely to show up as errors. The names match the
  // testgemm(M, N, K) parameters they are passed to.
  int M = 200, N = 57, K = 113;
//  int M = 100, N = 100, K = 100;

  //       <C layout,          A layout,          B layout         >
  testgemm<Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::LayoutLeft >(M,N,K);
  testgemm<Kokkos::LayoutLeft, Kokkos::LayoutRight,Kokkos::LayoutLeft >(M,N,K);
  testgemm<Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::LayoutRight>(M,N,K);
  testgemm<Kokkos::LayoutLeft, Kokkos::LayoutRight,Kokkos::LayoutRight>(M,N,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutLeft, Kokkos::LayoutLeft >(M,N,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutRight,Kokkos::LayoutLeft >(M,N,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutLeft, Kokkos::LayoutRight>(M,N,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutRight,Kokkos::LayoutRight>(M,N,K);

  }
  Kokkos::finalize();
}

— crtrott, Feb 15, 2022