const scaling parameters not supported for gemm_batch

Open AidanBeltonS opened this issue 1 year ago • 1 comments

Summary

gemm_batch, and possibly other batch methods, do not accept const Ts* for the alpha and beta scaling parameters. I believe they should, since that is how the parameters are documented in the oneMKL spec. It is also a reasonable parameter type: alpha and beta are read-only inputs and should not be modified. My reference: https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-dpcpp/2023-0/gemm-batch.html
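The language rule behind this request can be reproduced without oneMKL. The sketch below uses two hypothetical stand-in functions (`scale_mutable` and `scale_const` are illustrative names, not oneMKL APIs). It shows that a parameter declared `const float *` accepts both const and non-const pointers, while a parameter declared `float *` rejects a `const float *` argument.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for an API that declares its scaling array as `float *`,
// even though it only reads from it.
inline float scale_mutable(float *alpha, std::int64_t i, float x) {
    assert(alpha != nullptr);
    return alpha[i] * x;
}

// Hypothetical stand-in for an API that declares the same array as `const float *`.
inline float scale_const(const float *alpha, std::int64_t i, float x) {
    assert(alpha != nullptr);
    return alpha[i] * x;
}

inline float demo() {
    float coeff[2] = {2.0f, 0.5f};
    const float *read_only = coeff;  // read-only view of the same data

    float r1 = scale_const(coeff, 0, 3.0f);      // float* converts to const float* implicitly
    float r2 = scale_const(read_only, 1, 4.0f);  // const float* is accepted directly
    float r3 = scale_mutable(coeff, 0, 1.0f);    // non-const pointer is accepted

    // scale_mutable(read_only, 0, 1.0f);  // would not compile: the call discards const

    return r1 + r2 + r3;
}
```

Declaring the parameter as a pointer to const is therefore the less restrictive choice for callers: it costs nothing for code that holds mutable data and makes the call possible for code that holds read-only data.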

Version

oneMKL hash: 7d2044e202dbc67ff4eee598a4392edcd163deaf

Environment

oneMKL works with multiple HW and backend libraries and also depends on the compiler and build environment. Include the following information to help reproduce the issue:

  • HW: A100 GPU
  • Backend: cuBLAS
  • OS: Ubuntu 20.04
  • Compiler version: DPC++ 2024.0.2

Steps to reproduce

Compile for NVIDIA GPUs with: icpx -fsycl -fsycl-targets=nvptx64-nvidia-cuda reproducer_onemkl_batch.cpp -lonemkl or for Intel GPUs with: icpx -fsycl reproducer_onemkl_batch.cpp -lonemkl

#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

template <class Ta, class Tb, class Tc, class Ts> 
void run_gemm(sycl::queue q) {
    // Arbitrary placeholder data: the error occurs at compile time, so the values do not matter.
    const Ta *a[4] = {nullptr};
    const Tb *b[4] = {nullptr};
    Tc *c[4] = {nullptr};

    int64_t batch_size = 4;

    oneapi::mkl::transpose a_trans = oneapi::mkl::transpose::trans;
    oneapi::mkl::transpose b_trans = oneapi::mkl::transpose::nontrans;

    int64_t m = 10; 
    int64_t n = 10; 
    int64_t k = 10; 

    int64_t lda = 10; 
    int64_t ldb = 10; 
    int64_t ldc = 10; 

    int64_t group_size = 1;

    Ts alpha = 1;
    Ts beta = 0;
    oneapi::mkl::transpose *trans =
        reinterpret_cast<oneapi::mkl::transpose *>( 
            std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      trans[batch + batch_size * 0] = a_trans;
      trans[batch + batch_size * 1] = b_trans;
    }   

    // Layout: m, n, k, lda, ldb, ldc, group_size (batch_size entries each)
    int64_t *dims = reinterpret_cast<int64_t *>( 
        std::malloc(sizeof(int64_t) * 7 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      dims[batch + batch_size * 0] = m;
      dims[batch + batch_size * 1] = n;
      dims[batch + batch_size * 2] = k;

      dims[batch + batch_size * 3] = lda;
      dims[batch + batch_size * 4] = ldb;
      dims[batch + batch_size * 5] = ldc;

      dims[batch + batch_size * 6] = group_size;
    }   

    // Layout: alpha, beta (batch_size entries each)
    Ts *coeff =
        reinterpret_cast<Ts *>(std::malloc(sizeof(Ts) * 2 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      coeff[batch + batch_size * 0] = 1;
      coeff[batch + batch_size * 1] = 0;
    }


    oneapi::mkl::blas::column_major::gemm_batch(
        q, trans + batch_size * 0 /*a_trans*/,
        trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/,
        dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/,
        reinterpret_cast<const Ts*>(coeff + batch_size * 0) /*alpha*/,
        reinterpret_cast<const Ta **>(a), dims + batch_size * 3 /*lda*/,
        reinterpret_cast<const Tb **>(b), dims + batch_size * 4 /*ldb*/,
        reinterpret_cast<const Ts*>(coeff + batch_size * 1) /*beta*/, reinterpret_cast<Tc **>(c),
        dims + batch_size * 5 /*ldc*/, batch_size,
        dims + batch_size * 6 /*group_size*/);
}

int main() {
    sycl::queue q;
    run_gemm<float, float, float, float>(q);
}

Error:

reproducer_onemkl_batch.cpp:60:5: error: no matching function for call to 'gemm_batch'
   60 |     oneapi::mkl::blas::column_major::gemm_batch(
      |     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
reproducer_onemkl_batch.cpp:74:5: note: in instantiation of function template specialization 'run_gemm<float, float, float, float>' requested here
   74 |     run_gemm<float, float, float, float>(q);
      |     ^
/home/aidanbelton/source/oneMKL/include/oneapi/mkl/blas.hxx:2186:27: note: candidate function not viable: 7th argument ('const float *') would lose const qualifier
 2188 | static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa,
      |                           ^
 2189 |                                          transpose *transb, std::int64_t *m, std::int64_t *n,
 2190 |                                          std::int64_t *k, float *alpha, const float **a,

I would expect this to compile, based on the linked documentation and the fact that the parameter is read-only.
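Until the signature changes, one possible call-site workaround is to remove the qualifier with `const_cast`. The sketch below uses a hypothetical stand-in (`weighted_sum` is not a oneMKL function) whose scaling array is declared `float *`, mirroring the `float *alpha` shown in the compiler note above. The workaround relies on the assumption that the callee never writes through the pointer, which the type system no longer checks.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a function whose scaling array is declared
// non-const even though it is only read.
inline float weighted_sum(float *alpha, const float *x, std::int64_t n) {
    assert(alpha != nullptr && x != nullptr);
    float acc = 0.0f;
    for (std::int64_t i = 0; i < n; ++i) {
        acc += alpha[i] * x[i];
    }
    return acc;
}

// Call-site wrapper that accepts read-only coefficients.
// Assumption: weighted_sum never modifies alpha, so casting away const is harmless.
inline float weighted_sum_ro(const float *alpha, const float *x, std::int64_t n) {
    return weighted_sum(const_cast<float *>(alpha), x, n);
}

inline float workaround_demo() {
    const float alpha[3] = {1.0f, 2.0f, 3.0f};
    const float x[3] = {4.0f, 5.0f, 6.0f};
    return weighted_sum_ro(alpha, x, 3);
}
```

In the reproducer, the equivalent change would be to pass `const_cast<Ts *>(...)` for the alpha and beta arguments. This only hides the mismatch at the call site; a const-qualified signature in the library would remove the need for it.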

AidanBeltonS, Feb 15 '24 15:02

@AidanBeltonS This looks like a real gap. Thanks for reporting it; we will take a look. For non-array parameters, we typically don't use const, since they are passed by value. In this particular case, however, we are looking at the gemm_batch group API, where all parameters are arrays passed with const.
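The distinction between by-value and array parameters can be checked directly in standard C++ (C++17 assumed here). In this minimal sketch the variable names are illustrative only: top-level const on a by-value parameter is not part of the function type, whereas const on the pointee of a pointer parameter is.

```cpp
#include <type_traits>

// Top-level const on a by-value parameter is dropped from the function type,
// so these two declarations describe the same signature.
constexpr bool by_value_same =
    std::is_same_v<void(float), void(const float)>;

// const on the pointee is part of the type, so these signatures differ.
constexpr bool pointer_same =
    std::is_same_v<void(float *), void(const float *)>;

static_assert(by_value_same, "const on a by-value parameter does not change the signature");
static_assert(!pointer_same, "const on the pointee changes the signature");
```

This is why omitting const is harmless for scalar alpha and beta in the non-batched API, but matters for the group API, where alpha and beta are passed as pointers.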

mmeterel, Feb 15 '24 17:02