cython-blis
Non unit-stride ndarrays
Looking into the code, I'm pretty sure it is buggy for non-unit-stride ndarrays, such as those produced by slicing, reverse slicing, or broadcasting:
https://github.com/explosion/cython-blis/blob/c5df0793ead2f18127277cef236691e3ad16a9ff/blis/py.pyx#L64-L102
There is no check that the inputs are row-major, but passing &A[0,0], A.shape[1], 1 as the pointer, row stride and column stride assumes a contiguous row-major layout.
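To make the failure mode concrete, here is a small pure-Python sketch (no BLIS involved) of how a strided kernel locates element (i, j) from a base offset, a row stride and a column stride. The helper names `element_at` and `read_view` and the 3x4 example matrix are made up for illustration, and strides are counted in elements.

```python
# A 3x4 row-major matrix stored in one flat buffer:
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
ROWS, COLS = 3, 4
buf = list(range(ROWS * COLS))


def element_at(buf, offset, rs, cs, i, j):
    """Element (i, j) of a view with row stride rs and column stride cs."""
    return buf[offset + i * rs + j * cs]


def read_view(buf, offset, shape, rs, cs):
    """Materialize a strided view as a nested list."""
    return [[element_at(buf, offset, rs, cs, i, j) for j in range(shape[1])]
            for i in range(shape[0])]


# The view M[:, ::2] has shape (3, 2) and element strides (4, 2).
shape = (3, 2)
correct = read_view(buf, 0, shape, rs=4, cs=2)

# Assuming a contiguous row-major layout means passing (shape[1], 1) instead.
assumed = read_view(buf, 0, shape, rs=shape[1], cs=1)

print(correct)  # [[0, 2], [4, 6], [8, 10]]
print(assumed)  # [[0, 1], [2, 3], [4, 5]]
```

With the assumed strides the kernel silently reads the first six elements of the parent buffer rather than the sliced columns, so the result is wrong without any error being raised.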
Instead the code should probably be:
    def gemm(const_reals2d_ft A, const_reals2d_ft B,
             np.ndarray out=None, bint trans1=False, bint trans2=False,
             double alpha=1., double beta=1.):
        cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
        cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
        cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
        # .strides is reported in bytes, while BLIS takes strides in elements,
        # so each stride is divided by the element size (signed, so negative
        # strides keep their sign).
        cdef Py_ssize_t itemsize
        if const_reals2d_ft is const_float2d_t:
            itemsize = sizeof(float)
            if out is None:
                out = numpy.zeros((nM, nN), dtype='f')
            C = <float*>out.data
            with nogil:
                cy.gemm(
                    cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                    cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                    nM, nN, nK,
                    alpha,
                    &A[0,0], A.strides[0] // itemsize, A.strides[1] // itemsize,
                    &B[0,0], B.strides[0] // itemsize, B.strides[1] // itemsize,
                    beta,
                    C, out.strides[0] // itemsize, out.strides[1] // itemsize)
            return out
        elif const_reals2d_ft is const_double2d_t:
            itemsize = sizeof(double)
            if out is None:
                # Use nM/nN/nK here too so trans1/trans2 are honoured.
                out = numpy.zeros((nM, nN), dtype='d')
            C = <double*>out.data
            with nogil:
                cy.gemm(
                    cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                    cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                    nM, nN, nK,
                    alpha,
                    &A[0,0], A.strides[0] // itemsize, A.strides[1] // itemsize,
                    &B[0,0], B.strides[0] // itemsize, B.strides[1] // itemsize,
                    beta,
                    C, out.strides[0] // itemsize, out.strides[1] // itemsize)
            return out
        else:
            C = NULL
            raise TypeError("Unhandled fused type")
The same applies to gemv.
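One detail to watch when forwarding strides, for gemm and gemv alike: the buffer protocol (which NumPy arrays and Cython typed memoryviews follow) reports strides in bytes, whereas the BLIS typed API takes row/column strides and vector increments counted in elements. The sketch below uses only the standard library to show the conversion; `element_strides` is a hypothetical helper name, not part of cython-blis.

```python
from array import array

# A 1-D buffer of C doubles, viewed through the buffer protocol.
data = array('d', [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
view = memoryview(data)


def element_strides(mv):
    """Convert a buffer's byte strides to element strides."""
    return tuple(s // mv.itemsize for s in mv.strides)


every_other = view[::2]     # non-unit stride, as produced by slicing
reversed_view = view[::-1]  # negative stride, as produced by reverse slicing

print(every_other.strides, element_strides(every_other))
print(reversed_view.strides, element_strides(reversed_view))
```

The byte strides are multiples of the item size, so the integer division is exact, and the sign of a negative stride is preserved.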
This has several advantages:
- it works for any strides
- it can be faster than the default OpenBLAS/MKL path, since no conversion to a contiguous array is needed.
The main draw of the BLIS API is that it supports strided arrays without giving up performance, so this is the perfect use case.
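As a rough illustration of what a (row stride, column stride) interface buys you, here is a naive pure-Python reference for a strided gemm. It is only a sketch of the semantics, not how BLIS implements it; the function name `strided_gemm` and its argument order are assumptions made for this example, and strides are in elements.

```python
def strided_gemm(m, n, k, alpha,
                 a, a_off, rsa, csa,
                 b, b_off, rsb, csb,
                 beta,
                 c, c_off, rsc, csc):
    """C := beta * C + alpha * A @ B on flat buffers with element strides."""
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += (a[a_off + i * rsa + p * csa]
                        * b[b_off + p * rsb + j * csb])
            idx = c_off + i * rsc + j * csc
            c[idx] = beta * c[idx] + alpha * acc


# A is 2x3, row-major: [[1, 2, 3], [4, 5, 6]]
a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
c = [0.0] * 9  # 3x3 output, row-major

# Compute A.T @ A without copying: A.T is the same buffer with the
# row and column strides swapped (1, 3) instead of (3, 1).
strided_gemm(3, 3, 2, 1.0,
             a, 0, 1, 3,   # A.T as a strided view of a
             a, 0, 3, 1,   # A as stored
             0.0,
             c, 0, 3, 1)

print(c)  # [17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0]
```

The transposed operand is never materialized; only its strides change, which is the property that avoids the contiguous copy mentioned above.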
Thanks! I think you're right, will fix.