
Multiplying a complex matrix with a complex vector raises an error

Open · christian-cahig opened this issue 3 years ago · 3 comments

Hi, I was trying out CSR as a prospective tool for working with sparse matrices in Numba. In my use case, multiplying a complex matrix by a complex vector is common, e.g.:

import numpy as np
import scipy as sp
import scipy.sparse  # make sure sp.sparse is loaded (needed on older SciPy versions)
import csr

# Complex matrix
A_ = sp.sparse.random(5, 5, density=0.1, format='csr') + 1j*sp.sparse.random(5, 5, density=0.1, format='csr')
A = csr.create(A_.shape[0], A_.shape[1], A_.nnz, A_.indptr, A_.indices, A_.data)

# Complex vector
b = np.random.random(5) + 1j*np.random.random(5)

# Expected output is a complex vector
A.mult_vec(b)
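For reference, the expected complex result can be written entirely in terms of real quantities. Splitting the matrix and vector into real and imaginary parts (the symbols A_r, A_i, b_r, b_i are notation introduced here, not part of the csr API), the product needs only four real matrix-vector products. That means a kernel that only handles real values could, in principle, be used as a stopgap:

```latex
\begin{aligned}
A &= A_r + i\,A_i, \qquad b = b_r + i\,b_i, \\
A b &= (A_r b_r - A_i b_i) + i\,(A_r b_i + A_i b_r).
\end{aligned}
```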

However, the last line of the snippet above raises a TypingError:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(float64, 1d, C), int64, complex128)
 
There are 16 candidate implementations:
      - Of which 16 did not match due to:
      Overload of function 'setitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 1d, C), int64, complex128)':
       No match.

During: typing of setitem at absolute\path\to\my\conda-env\lib\site-packages\csr\kernels\numba\__init__.py (62)

File "relative\path\to\my\conda-env\lib\site-packages\csr\kernels\numba\__init__.py", line 62:
def mult_vec(h: CSR, v):
    <source elided>
        col = h.colinds[i]
        res[row] += v[col] * h._e_value(i)
        ^
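The traceback suggests the result buffer `res` is allocated as `float64`, so storing the `complex128` product `v[col] * h._e_value(i)` into it has no valid `setitem` overload. Below is a minimal sketch of one way around this, assuming the result dtype should be promoted from both operands. It is not the csr package's code: `csr_matvec` and `_csr_matvec_into` are hypothetical names, and the sketch works on raw `rowptrs`/`colinds`/`values` arrays. The dtype is chosen in plain Python with `np.result_type` and the buffer is passed into the jitted loop, which keeps the Numba kernel itself simple.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if Numba is not installed
    def njit(f):
        return f


@njit
def _csr_matvec_into(rowptrs, colinds, values, v, res):
    # Hypothetical kernel: for each row, accumulate A[row, col] * v[col]
    # over the stored entries rowptrs[row] .. rowptrs[row + 1] - 1.
    nrows = rowptrs.shape[0] - 1
    for row in range(nrows):
        for i in range(rowptrs[row], rowptrs[row + 1]):
            res[row] += values[i] * v[colinds[i]]


def csr_matvec(rowptrs, colinds, values, v):
    """Hypothetical helper: CSR matrix times dense vector with a promoted result dtype."""
    # Promote from both operands (e.g. float64 with complex128 -> complex128)
    # instead of hard-coding float64, so complex products can be stored.
    out_dtype = np.result_type(values.dtype, v.dtype)
    res = np.zeros(rowptrs.shape[0] - 1, dtype=out_dtype)
    _csr_matvec_into(rowptrs, colinds, values, v, res)
    return res


# Usage: the 2x2 matrix [[0, 2+1j], [3, 0]] in CSR form times [1+1j, 2].
rowptrs = np.array([0, 1, 2])
colinds = np.array([1, 0])
values = np.array([2.0 + 1j, 3.0])
v = np.array([1.0 + 1j, 2.0])
y = csr_matvec(rowptrs, colinds, values, v)  # complex128 result: [4+2j, 3+3j]
```

Because Numba compiles a separate specialization for each combination of argument types, the same kernel also serves purely real inputs, where `np.result_type` keeps the result as `float64`.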

christian-cahig, Dec 19 '21 01:12