sycl-blas

Extended Gemm interface to support mixed precision operations

Open. OuadiElfarouki opened this issue (0 comments).

Extend the Gemm operator interface to support mixed-precision operations by decoupling the element type of the input matrices A and B (element_in_t) from the type of the output matrix C and the scalars alpha and beta (element_out_t).
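To make the decoupling concrete, here is a minimal host-side reference sketch. It assumes column-major storage and no transposition. The function name gemm_reference and its parameter order are hypothetical and do not reflect the actual sycl-blas interface; only the type names element_in_t and element_out_t come from this PR.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical reference GEMM (not the sycl-blas API): computes
// C = alpha * A * B + beta * C for column-major, non-transposed operands,
// where A is m x k, B is k x n and C is m x n.
// element_in_t is the type of A and B; element_out_t is the type of C,
// alpha and beta. Products are accumulated in element_out_t.
template <typename element_in_t, typename element_out_t>
void gemm_reference(std::size_t m, std::size_t n, std::size_t k,
                    element_out_t alpha, const element_in_t* a, std::size_t lda,
                    const element_in_t* b, std::size_t ldb, element_out_t beta,
                    element_out_t* c, std::size_t ldc) {
  // Leading dimensions must cover the column heights in column-major layout.
  assert(lda >= m && ldb >= k && ldc >= m);
  for (std::size_t j = 0; j < n; ++j) {
    for (std::size_t i = 0; i < m; ++i) {
      element_out_t acc = 0;
      for (std::size_t p = 0; p < k; ++p) {
        // Widen each input element to the output type before multiplying.
        acc += static_cast<element_out_t>(a[i + p * lda]) *
               static_cast<element_out_t>(b[p + j * ldb]);
      }
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}
```

In the PR the mixed case is element_in_t = sycl::half with element_out_t = float. Because sycl::half needs a SYCL toolchain, a call such as gemm_reference(2, 2, 2, 2.0, a, 2, b, 2, 1.0, c, 2) with float arrays a and b and a double array c serves as a stand-in for that pair. Widening before the multiply keeps the accumulation in the output precision.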

Following the notation of oneMKL's Gemm API specification (https://spec.oneapi.io/versions/latest/elements/oneMKL/source/domains/blas/gemm.html#onemkl-blas-gemm), this PR allows the input type (Ta == Tb) to be set independently of the output and scalar type (Tc == Ts). As a first stage, this is enabled for Ta = Tb = sycl::half and Tc = Ts = float. As a result, enabling half support also enables the mixed-precision (half, float) case for gemm.
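The constraint above can be expressed as a compile-time check. This is a sketch, not sycl-blas code: the trait name representable_with_two_types and the placeholder half_stub (standing in for sycl::half so the snippet compiles without a SYCL implementation) are both hypothetical.

```cpp
#include <type_traits>

// Hypothetical stand-in for sycl::half; not part of sycl-blas or oneMKL.
struct half_stub {};

// Hypothetical trait: with two template parameters (element_in_t and
// element_out_t), a oneMKL-style combination (Ta, Tb, Tc, Ts) is
// representable only if Ta == Tb and Tc == Ts.
template <typename Ta, typename Tb, typename Tc, typename Ts>
inline constexpr bool representable_with_two_types =
    std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Ts>;

// First-stage mixed-precision case: half inputs, float output and scalars.
static_assert(representable_with_two_types<half_stub, half_stub, float, float>);
// Same-precision gemm is the special case element_in_t == element_out_t.
static_assert(representable_with_two_types<float, float, float, float>);
// Ta == Tb == Tc with a separate Ts does not fit the two-type interface.
static_assert(!representable_with_two_types<half_stub, half_stub, half_stub, float>);
```

The last check corresponds to the case described in the note below, which needs a third, independent scalar type and therefore cannot be covered by this two-parameter design.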

Changes include:

  • Updates to the various Gemm kernel implementations to account for the decoupled types.
  • Updates to CMake and the kernel-generation scripts so that gemm is generated for a pair of types rather than a single type.
  • Updates to the unit tests to cover this feature in the Gemm case.

Note: Following oneMKL's expected Gemm API, support for the bfloat16-float and std::int8_t-float combinations should be straightforward afterwards. However, the additional cases where Ta == Tb == Tc while Ts (the type of alpha and beta) differs will require further decoupling and redesign work.

OuadiElfarouki, Feb 29 '24 13:02