General oneAPI gemm_batch without a contiguity constraint
Hi there,
Compared to torch.bmm, which as far as I can tell requires the A/B/C arrays to be contiguous and every entry in the batch to be the same size, oneMKL has nice pointer-list-based gemm_batch calls, for example:
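```cpp
// Grouped, pointer-list batched SGEMM (column-major, USM API).
// A single group (group_count = 1) of batchCount64 matrices is submitted.
oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle,   // sycl::queue
    &iOpA, &iOpB,                                      // per-group transpose ops
    &m64, &n64, &k64,                                  // per-group sizes
    (float *) &alpha_p[0],                             // per-group alpha
    (const float **) &Amk[0], (const int64_t *) &lda64,  // one A pointer per matrix
    (const float **) &Bkn[0], (const int64_t *) &ldb64,  // one B pointer per matrix
    (float *) &beta_p[0],                              // per-group beta
    (float **) &Cmn[0], (const int64_t *) &ldc64,        // one C pointer per matrix
    (int64_t) 1, &batchCount64,                        // group_count, group_size
    std::vector<sycl::event>());                       // no dependencies
```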
In fact, you can submit multiple groups, each with its own M, N and K.
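As a rough illustration of what the grouped pointer-list interface computes, here is a plain C++ CPU reference. It is not oneMKL code: it assumes fp32, column-major storage and no transposes, and the names `GemmGroup` and `grouped_gemm_reference` are hypothetical. Each group shares m, n, k, the leading dimensions, alpha and beta (mirroring the per-group arrays of length group_count in the call above), while the A/B/C pointer lists hold one entry per matrix, so the matrices can live anywhere in memory.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical description of one group: all matrices in the group share
// sizes, leading dimensions and alpha/beta, but each has its own pointer,
// so they need not be contiguous or evenly strided.
struct GemmGroup {
    std::int64_t m, n, k;
    std::int64_t lda, ldb, ldc;
    float alpha, beta;
    std::vector<const float*> a;  // one pointer per matrix in the group
    std::vector<const float*> b;
    std::vector<float*> c;
};

// CPU reference for the semantics only (column-major, no transposes):
// for every matrix i in every group, C_i = alpha * A_i * B_i + beta * C_i.
void grouped_gemm_reference(const std::vector<GemmGroup>& groups) {
    for (const GemmGroup& g : groups) {
        assert(g.a.size() == g.b.size() && g.b.size() == g.c.size());
        for (std::size_t i = 0; i < g.a.size(); ++i) {
            const float* A = g.a[i];
            const float* B = g.b[i];
            float* C = g.c[i];
            for (std::int64_t col = 0; col < g.n; ++col) {
                for (std::int64_t row = 0; row < g.m; ++row) {
                    float acc = 0.0f;
                    for (std::int64_t p = 0; p < g.k; ++p) {
                        acc += A[row + p * g.lda] * B[p + col * g.ldb];
                    }
                    float& c_ref = C[row + col * g.ldc];
                    // Follow the BLAS convention: with beta == 0, C is not read.
                    c_ref = (g.beta == 0.0f) ? g.alpha * acc
                                             : g.alpha * acc + g.beta * c_ref;
                }
            }
        }
    }
}
```

With a single group this reduces to the uniform-size pointer-list case that cublasSgemmBatched covers.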
CUDA and HIP also support pointer-list-based batched GEMM (e.g. cublasSgemmBatched). These lack the per-group MNK generality, but they likewise drop the contiguity constraint.
Is there a PyTorch extension that exposes this more flexible interface, accepting a list of PyTorch tensors without first repacking them into a contiguous tensor?
If not, is one planned?
The torch.bmm operator, which appears to be mapped onto oneDNN and other libraries, is documented here:
https://pytorch.org/docs/stable/generated/torch.bmm.html
Unless I'm mistaken, it has the packed-tensor constraint. From the docs:
"input and mat2 must be 3-D tensors each containing the same number of matrices."
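To make the contrast concrete, here is a small C++ sketch of the two layouts. It assumes fp32 matrices that each sit in their own `std::vector`; the helper names `pack_contiguous`, `pointer_list` and `strided_as_pointer_list` are hypothetical and not part of PyTorch or oneMKL. Packing into one 3-D buffer costs a full copy of every matrix (on the PyTorch side, building the 3-D input from a list with torch.stack is a comparable copy), whereas a pointer list only collects addresses. A packed, strided batch is just the special case where pointer i equals base + i * stride.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copy matrices held in separate allocations into one
// contiguous buffer, matrix i starting at offset i * elems_per_matrix.
// This is the repacking a single packed 3-D tensor layout implies.
std::vector<float> pack_contiguous(const std::vector<std::vector<float>>& mats,
                                   std::size_t elems_per_matrix) {
    std::vector<float> packed;
    packed.reserve(mats.size() * elems_per_matrix);
    for (const auto& m : mats) {
        assert(m.size() >= elems_per_matrix);
        packed.insert(packed.end(), m.begin(),
                      m.begin() + static_cast<std::ptrdiff_t>(elems_per_matrix));
    }
    return packed;
}

// Pointer-list view of the same matrices: no copy, only addresses.
std::vector<const float*> pointer_list(const std::vector<std::vector<float>>& mats) {
    std::vector<const float*> ptrs;
    ptrs.reserve(mats.size());
    for (const auto& m : mats) ptrs.push_back(m.data());
    return ptrs;
}

// A strided (packed) batch expressed as a pointer list:
// ptrs[i] == base + i * stride.
std::vector<const float*> strided_as_pointer_list(const float* base,
                                                  std::size_t stride,
                                                  std::size_t batch) {
    std::vector<const float*> ptrs(batch);
    for (std::size_t i = 0; i < batch; ++i) ptrs[i] = base + i * stride;
    return ptrs;
}
```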
Hi @paboyle, Intel Extension for PyTorch aims to keep its interfaces similar to PyTorch's, so this feature request may be more appropriate for upstream PyTorch.