Non unit-stride ndarrays

Open: mratsim opened this issue 6 years ago • 1 comment

Looking at the code, I'm fairly sure it is buggy for non-unit-stride ndarrays, such as those produced by slicing, reverse slicing, or broadcasting:

https://github.com/explosion/cython-blis/blob/c5df0793ead2f18127277cef236691e3ad16a9ff/blis/py.pyx#L64-L102

There is no check that the inputs are row-major (C-contiguous), yet passing &A[0,0], A.shape[1], 1 as pointer, row stride, and column stride assumes exactly that layout.
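To make the failure mode concrete, here is a minimal pure-Python sketch of strided addressing. The helper `at` is hypothetical and only for illustration (it is not part of cython-blis); strides are counted in elements.

```python
# Hypothetical helper: read element (i, j) of a 2-D view stored in a flat
# buffer, given an offset and row/column strides counted in elements.
def at(buf, offset, rs, cs, i, j):
    return buf[offset + i * rs + j * cs]

# A 3x4 row-major matrix stored flat; its element strides are (4, 1).
base = list(range(12))

# The view base[:, ::2] has shape (3, 2) and element strides (4, 2).
shape = (3, 2)
true_view = [[at(base, 0, 4, 2, i, j) for j in range(shape[1])]
             for i in range(shape[0])]

# Hard-coding (shape[1], 1) as the strides, i.e. (2, 1), walks the
# buffer as if the view were a contiguous 3x2 matrix.
assumed_view = [[at(base, 0, shape[1], 1, i, j) for j in range(shape[1])]
                for i in range(shape[0])]

print(true_view)     # [[0, 2], [4, 6], [8, 10]]
print(assumed_view)  # [[0, 1], [2, 3], [4, 5]]
```

With the hard-coded strides the call does not fail; it silently reads the wrong elements.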

Instead, the code should probably pass the actual strides:


def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    # Logical dimensions after applying the optional transposes.
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    # Memoryview and ndarray strides are in bytes, while BLIS expects
    # strides in elements, so divide by the item size (kept signed so
    # that negative strides survive the division).
    cdef Py_ssize_t sz
    if const_reals2d_ft is const_float2d_t:
        sz = sizeof(float)
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.strides[0] // sz, A.strides[1] // sz,
                &B[0,0], B.strides[0] // sz, B.strides[1] // sz,
                beta,
                C, out.strides[0] // sz, out.strides[1] // sz)
        return out
    elif const_reals2d_ft is const_double2d_t:
        sz = sizeof(double)
        if out is None:
            out = numpy.zeros((nM, nN), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.strides[0] // sz, A.strides[1] // sz,
                &B[0,0], B.strides[0] // sz, B.strides[1] // sz,
                beta,
                C, out.strides[0] // sz, out.strides[1] // sz)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")

The same applies to gemv.
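One detail worth checking when passing strides through: the buffer protocol reports strides in bytes, whereas the existing call passes element counts (A.shape[1], 1). A small standard-library sketch of the conversion, using `memoryview` as a stand-in for an ndarray (the variable names are illustrative only):

```python
from array import array

# A 3x4 matrix of C floats, exposed through the buffer protocol.
flat = array('f', range(12))
view = memoryview(flat).cast('B').cast('f', shape=[3, 4])

# Strides as reported by the buffer protocol are in bytes.
byte_strides = view.strides

# Dividing by the item size gives strides in elements, the unit used
# by the (A.shape[1], 1) arguments in the current call.
elem_strides = tuple(s // view.itemsize for s in byte_strides)

print(byte_strides, view.itemsize, elem_strides)
```

For a contiguous 3x4 array the element strides come out as (4, 1), which matches what the current code passes for that one layout.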

This has several advantages:

  • works for any strides
  • faster than the default OpenBLAS/MKL path, as no conversion to a contiguous array is needed.

The main draw of the BLIS API is that it supports strided arrays without giving up performance, so this is the perfect use case.
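As a rough illustration of what a stride-aware gemm computes, here is a naive pure-Python reference. It is a sketch of the semantics only, not how BLIS implements it, and `gemm_strided` with its argument order is a hypothetical name chosen for this example. Strides are in elements.

```python
# Naive reference: C := beta * C + alpha * (A @ B), where A is m x k,
# B is k x n and C is m x n, each addressed through a flat buffer with
# an offset plus row and column strides counted in elements.
def gemm_strided(m, n, k, alpha,
                 a, a_off, rsa, csa,
                 b, b_off, rsb, csb,
                 beta, c, rsc, csc):
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += (a[a_off + i * rsa + p * csa]
                        * b[b_off + p * rsb + j * csb])
            idx = i * rsc + j * csc
            c[idx] = beta * c[idx] + alpha * acc
    return c

# A is the transpose of the row-major 2x2 matrix [[1, 2], [3, 4]]:
# same buffer, strides swapped from (2, 1) to (1, 2), no copy.
a = [1.0, 2.0, 3.0, 4.0]

# B is every other column of a row-major 2x4 matrix: strides (4, 2).
b = [10.0, 0.0, 20.0, 0.0,
     30.0, 0.0, 40.0, 0.0]

c = [0.0] * 4  # contiguous 2x2 output, strides (2, 1)
gemm_strided(2, 2, 2, 1.0, a, 0, 1, 2, b, 0, 4, 2, 1.0, c, 2, 1)
print(c)  # [100.0, 140.0, 140.0, 200.0]
```

Neither input is copied into a contiguous buffer first; the transposed and sliced views are consumed directly through their strides.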

mratsim, Sep 15 '19 15:09

Thanks! I think you're right, will fix.

honnibal, Sep 15 '19 16:09