stdBLAS
stdBLAS copied to clipboard
GEMM Strides/Extents to CBLAS mapping
Does this mapping look right? For each matrix, the table lists its extents as (extent(0),extent(1)) followed by its strides as (stride(0),stride(1)).
| C | A | B | defacto layouts | CBLAS call |
|---|---|---|---|---|
| (M,N), (1,M) | (M,K), (1,M) | (K,N), (1,K) | (left,left,left) | gemm('N','N', M, N, K, 1., A.data(), K, B.data(), N, 1., C.data(), N) |
| (M,N), (1,M) | (M,K), (K,1) | (K,N), (1,K) | (left,right,left) | gemm('T','N', N, M, K, 1., A.data(), M, B.data(), N, 1., C.data(), N) |
| (M,N), (1,M) | (M,K), (1,M) | (K,N), (N,1) | (left,left,right) | gemm('N','T', M, N, K, 1., A.data(), K, B.data(), K, 1., C.data(), N) |
| (M,N), (1,M) | (M,K), (K,1) | (K,N), (N,1) | (left,right,right) | gemm('T','T', M, N, K, 1., A.data(), M, B.data(), K, 1., C.data(), N) |
| (M,N), (N,1) | (M,K), (1, M) | (K,N), (1,K) | (right,left,left) | gemm('T','T', N, M, K, 1., B.data(), N, A.data(), K, 1., C.data(), M) |
| (M,N), (N,1) | (M,K), (K, 1) | (K,N), (1,K) | (right,right,left) | gemm('T','N', N, M, K, 1., B.data(), N, A.data(), M, 1., C.data(), M) |
| (M,N), (N,1) | (M,K), (1, M) | (K,N), (N,1) | (right,left,right) | gemm('N','T', N, M, K, 1., B.data(), K, A.data(), K, 1., C.data(), M) |
| (M,N), (N,1) | (M,K), (K,1) | (K,N), (N,1) | (right,right,right) | gemm('N','N', N, M, K, 1., B.data(), K, A.data(), M, 1., C.data(), M) |
Here is the test code:
#include<Kokkos_Core.hpp>
#include<Kokkos_Random.hpp>
// Declaration of the external dgemm_ routine (C linkage, trailing-underscore
// symbol name), which must be supplied by a BLAS library at link time.
// As used by the calls in gemm() below, the arguments are, in order:
//   transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
// Every scalar is passed by pointer; the two flags are passed as C strings
// such as "N" or "T".
extern "C" void dgemm_(const char*, const char*, int*, int*, int*, double*, double*, int*, double*, int*, double*, double*, int*);
/// Computes C = A * B for rank-2 views of double by calling dgemm_.
///
/// Each view is classified exactly once from its strides:
///   - stride(0) == 1  -> treated as column-major (checked first),
///   - otherwise stride(1) == 1 -> treated as row-major.
/// The transpose flags and the leading dimensions are both derived from that
/// single classification. This matters for degenerate views (an extent of 1
/// can make both strides equal to 1): the flags can no longer disagree with
/// the leading dimension, and dgemm_ is invoked at most once.
///
/// alpha is 1 and beta is 0, so the previous contents of C are overwritten.
///
/// If any of the three views has no unit stride, no BLAS call is made and C
/// is left untouched.
///
/// @param C  output view; extents (M, N)
/// @param A  left operand; extents (M, K)
/// @param B  right operand; extents (K, N)
template<class AT, class BT, class CT>
void gemm(CT C, AT A, BT B) {
  // Debug aids:
  // printf("C: %i %i %i %i\n",C.extent_int(0),C.extent_int(1),int(C.stride(0)),int(C.stride(1)));
  // printf("A: %i %i %i %i\n",A.extent_int(0),A.extent_int(1),int(A.stride(0)),int(A.stride(1)));
  // printf("B: %i %i %i %i\n",B.extent_int(0),B.extent_int(1),int(B.stride(0)),int(B.stride(1)));

  const bool A_col = A.stride(0) == 1;
  const bool B_col = B.stride(0) == 1;
  const bool C_col = C.stride(0) == 1;
  const bool A_row = !A_col && A.stride(1) == 1;
  const bool B_row = !B_col && B.stride(1) == 1;
  const bool C_row = !C_col && C.stride(1) == 1;

  // Without a unit stride the data cannot be described to dgemm_ by a single
  // leading dimension; do nothing in that case.
  if (!(A_col || A_row) || !(B_col || B_row) || !(C_col || C_row)) return;

  int M = static_cast<int>(C.extent(0));
  int N = static_cast<int>(C.extent(1));
  int K = static_cast<int>(A.extent(1));

  // Leading dimension = length of the contiguous dimension.
  int LDA = static_cast<int>(A_col ? A.extent(0) : A.extent(1));
  int LDB = static_cast<int>(B_col ? B.extent(0) : B.extent(1));
  int LDC = static_cast<int>(C_col ? C.extent(0) : C.extent(1));
  // BLAS requires every leading dimension to be at least 1, even when the
  // corresponding extent is 0.
  if (LDA < 1) LDA = 1;
  if (LDB < 1) LDB = 1;
  if (LDC < 1) LDC = 1;

  double alpha = 1., beta = 0.;
  double* A_data = A.data();
  double* B_data = B.data();
  double* C_data = C.data();

  if (C_col) {
    // Column-major C: C = op(A) * op(B). A row-major operand, read as
    // column-major storage, is the transpose of what we need, hence "T".
    const char* transA = A_col ? "N" : "T";
    const char* transB = B_col ? "N" : "T";
    dgemm_(transA,transB,&M,&N,&K,&alpha,A_data,&LDA,B_data,&LDB,&beta,C_data,&LDC);
  } else {
    // Row-major C is C^T in column-major storage, and C^T = B^T * A^T.
    // The operands swap places, M and N swap, and the flags invert: a
    // column-major operand now needs "T", a row-major one "N".
    const char* transB = B_col ? "T" : "N";
    const char* transA = A_col ? "T" : "N";
    dgemm_(transB,transA,&N,&M,&K,&alpha,B_data,&LDB,A_data,&LDA,&beta,C_data,&LDC);
  }
}
/// Runs gemm() for one combination of layouts and compares the result with a
/// plain triple-loop product.
///
/// A (M x K) and B (K x N) are filled with random values; the reference
/// product is written to C2 and the gemm() result to C. The number of entries
/// whose absolute difference exceeds 1e-13 (or is NaN) is printed as
/// "Errors: <count>". The entry at (3,3) of both matrices is also printed.
///
/// @tparam LC  layout of C and C2
/// @tparam LA  layout of A
/// @tparam LB  layout of B
/// @param M  rows of A and C
/// @param N  columns of B and C
/// @param K  columns of A / rows of B
template<class LC, class LA, class LB>
void testgemm(int M, int N, int K) {
  Kokkos::View<double**,LA> A("A",M,K);
  Kokkos::View<double**,LB> B("B",K,N);
  Kokkos::View<double**,LC> C("C",M,N),C2("C2",M,N);

  Kokkos::Random_XorShift64_Pool<> g(1321);
  Kokkos::fill_random(A,g,1.0);
  Kokkos::fill_random(B,g,1.0);

  Kokkos::parallel_for("CreateReference",
    Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0,0}, {C.extent(0), C.extent(1)}),
    KOKKOS_LAMBDA(int i, int j) {
      const int inner = A.extent_int(1);
      double sum = 0;
      for(int k=0; k<inner; k++) {
        sum += A(i,k)*B(k,j);
      }
      C2(i,j) = sum;
    }
  );

  // Kernel launches may return before the work has finished; gemm() hands raw
  // pointers to a host BLAS routine, so make sure A and B are fully written.
  Kokkos::fence();
  gemm(C,A,B);

  int total_errors = 0;
  Kokkos::parallel_reduce("CheckEquivalence",
    Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0,0}, {C.extent(0), C.extent(1)}),
    KOKKOS_LAMBDA(int i, int j, int& errors) {
      // Two-sided test: a deviation in either direction counts, and the
      // negated form also counts NaN (every comparison with NaN is false).
      const double diff = C(i,j) - C2(i,j);
      if(!(diff <= 1e-13 && diff >= -1e-13)) errors++;
      if(i==3 && j==3) printf("%lf %lf\n",C(i,j),C2(i,j));
    },total_errors);
  printf("Errors: %i\n",total_errors);
}
int main(int argc, char* argv[]) {
Kokkos::initialize(argc,argv);
{
int N = 200, M = 57, K=113;
// int N = 100, M = 100, K=100;
testgemm<Kokkos::LayoutLeft,Kokkos::LayoutLeft,Kokkos::LayoutLeft>(N,M,K);
testgemm<Kokkos::LayoutLeft,Kokkos::LayoutRight,Kokkos::LayoutLeft>(N,M,K);
testgemm<Kokkos::LayoutLeft,Kokkos::LayoutLeft,Kokkos::LayoutRight>(N,M,K);
testgemm<Kokkos::LayoutLeft,Kokkos::LayoutRight,Kokkos::LayoutRight>(N,M,K);
testgemm<Kokkos::LayoutRight,Kokkos::LayoutLeft,Kokkos::LayoutLeft>(N,M,K);
testgemm<Kokkos::LayoutRight,Kokkos::LayoutRight,Kokkos::LayoutLeft>(N,M,K);
testgemm<Kokkos::LayoutRight,Kokkos::LayoutLeft,Kokkos::LayoutRight>(N,M,K);
testgemm<Kokkos::LayoutRight,Kokkos::LayoutRight,Kokkos::LayoutRight>(N,M,K);
}
Kokkos::finalize();
}