
Enh: Accept "duck arrays" for `tensordot`

Status: Open. brendan-m-murphy opened this issue 1 year ago (8 comments).

Please describe the purpose of the new feature or describe the problem to solve.

Sparse's `tensordot` only allows multiplication between sparse arrays and either SciPy sparse arrays or NumPy `ndarray`s.

It would be useful if other array-like objects were allowed.

For instance, in xarray, the `dot` function can only multiply a sparse `DataArray` by a Dask `DataArray` if the `einsum`/`tensordot` function from Dask is used: https://github.com/pydata/xarray/issues/9934

Suggest a solution if possible.

The code in `_dot` for multiplying a COO matrix by an `np.ndarray` seems to rely mostly on being able to infer the dtype, index into the arrays, and create empty ndarrays, so it seems plausible that other array-like objects could be used here.

I haven't tried to implement this though.

If you have tried alternatives, please describe them below.

No response

Additional information that may help us understand your needs.

Please see this issue for further discussion in the context of xarray: https://github.com/pydata/xarray/issues/9934

brendan-m-murphy, Jan 09 '25 17:01

Unfortunately, it uses Numba under the hood (which only accepts NumPy arrays) to do the actual computation; doing this with Dask (or even with NumPy but without Numba) would be excruciatingly slow.

However, we do support NumPy arrays with the `@` operator, on either side. It seems to me that xarray is having trouble detecting this.

```python
>>> import numpy as np; import sparse
>>> sp_arr = sparse.zeros((5, 5), dtype=sparse.float32)
>>> np_arr = np.zeros((5, 5), dtype=np.float32)
>>> sp_arr @ np_arr
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> np_arr @ sp_arr
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
```

I'd be happy to hear an alternative solution that would help xarray and this use case, but supporting duck arrays in `@` is beyond the scope of this library.

hameerabbasi, Jan 10 '25 12:01

Thanks, I suspected it might not be so simple.

I think Dask is using `sp_arr @ np_arr` on each chunk, and I believe `@` works properly if the underlying array is an `np.ndarray` rather than a Dask array.
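To illustrate why per-chunk products avoid the problem, here is a simplified NumPy-only sketch of the general idea; it is an assumption for illustration, not Dask's actual `tensordot` implementation. The helper `chunked_matmul` is hypothetical: it splits the contracted dimension into blocks, multiplies each pair of in-memory blocks with `@`, and sums the partial results, so each `@` call only ever sees concrete arrays.

```python
import numpy as np


def chunked_matmul(a, b, chunk=2):
    """Hypothetical helper: compute a @ b by chunking the contracted dimension.

    Simplified illustration only, NOT Dask's real algorithm.
    """
    if a.shape[1] != b.shape[0]:
        raise ValueError("inner dimensions do not match")
    out = np.zeros((a.shape[0], b.shape[1]), dtype=np.result_type(a, b))
    for start in range(0, a.shape[1], chunk):
        stop = start + chunk
        # Each block is a concrete in-memory array, so the library that owns
        # `a_block` only ever sees a plain ndarray on the other side of `@`.
        a_block = a[:, start:stop]
        b_block = b[start:stop, :]
        out += a_block @ b_block
    return out


a = np.arange(12, dtype=np.float64).reshape(3, 4)
b = np.arange(8, dtype=np.float64).reshape(4, 2)
print(np.allclose(chunked_matmul(a, b), a @ b))  # True
```

Summing over chunks of the contracted axis gives the same result as the full product because matrix multiplication is a sum over that axis.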

The linked xarray issue has some workarounds to get Dask to initiate the multiplication, rather than sparse, so I will go back and ask if any of those could be incorporated into xarray.

brendan-m-murphy, Jan 15 '25 09:01

If the variant that works is implemented by Dask, wouldn't it be fine to have sparse return `NotImplemented` and have it use `dask.array.Array.__rmatmul__` instead?
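A minimal pure-Python sketch of the fallback proposed here, using hypothetical `SparseLike` and `DaskLike` classes (stand-ins, not the real sparse or Dask types): when the left operand's `__matmul__` returns `NotImplemented`, Python calls the right operand's `__rmatmul__`.

```python
class SparseLike:
    """Hypothetical stand-in for a sparse array that only multiplies its own type."""

    def __matmul__(self, other):
        if not isinstance(other, SparseLike):
            # Return (do not raise) NotImplemented so Python tries
            # the right operand's reflected method next.
            return NotImplemented
        return "SparseLike @ SparseLike"


class DaskLike:
    """Hypothetical stand-in for a chunked array that can handle SparseLike."""

    def __rmatmul__(self, other):
        # Called for `other @ self` after other.__matmul__ returned NotImplemented.
        return f"DaskLike.__rmatmul__ handled {type(other).__name__} @ DaskLike"


print(SparseLike() @ DaskLike())
# DaskLike.__rmatmul__ handled SparseLike @ DaskLike
```

If `DaskLike` did not define `__rmatmul__` either, Python would raise `TypeError` for the expression.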

keewis, Jan 15 '25 13:01

That we can do; can you open an issue for that?

hameerabbasi, Jan 15 '25 13:01

Sorry, it looks like that already happens (I think)? See https://github.com/pydata/sparse/blob/2bca00c37654f8e1116dd5c988779f68665479de/sparse/numba_backend/_coo/core.py#L958-L972

I'll have to look into why that is not triggered.

Edit: maybe because the error is a `TypeError` and not a `NotImplementedError`?

Edit: or maybe because we're calling `tensordot` and not `matmul`?
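For reference, Python's binary-operator rules can be checked with the hypothetical classes below: only *returning* the `NotImplemented` singleton lets Python try the right operand's `__rmatmul__`. Raising `TypeError` or `NotImplementedError` ends the expression immediately, and a plain function call (such as a call to a `tensordot`-style function) has no reflected fallback at all. `try_matmul` and `tensordot_like` are illustrative helpers, not part of any library.

```python
class Other:
    """Hypothetical right operand that implements the reflected operator."""

    def __rmatmul__(self, left):
        return "handled by Other.__rmatmul__"


class RaisesTypeError:
    def __matmul__(self, other):
        # Raising aborts the expression; Other.__rmatmul__ is never tried.
        raise TypeError("unsupported operand type")


class RaisesNotImplementedError:
    def __matmul__(self, other):
        # NotImplementedError is an ordinary exception, unrelated to the
        # NotImplemented singleton, so it blocks the fallback as well.
        raise NotImplementedError


class ReturnsNotImplemented:
    def __matmul__(self, other):
        # Only returning the NotImplemented singleton triggers the fallback.
        return NotImplemented


def try_matmul(left, right):
    """Evaluate `left @ right`, reporting any raised exception by name."""
    try:
        return left @ right
    except (TypeError, NotImplementedError) as exc:
        return f"{type(exc).__name__} raised"


def tensordot_like(a, b):
    """Hypothetical library function. A plain call has no reflected
    fallback: returning NotImplemented just hands that object back."""
    return NotImplemented


for left in (RaisesTypeError(), RaisesNotImplementedError(), ReturnsNotImplemented()):
    print(type(left).__name__, "->", try_matmul(left, Other()))
# RaisesTypeError -> TypeError raised
# RaisesNotImplementedError -> NotImplementedError raised
# ReturnsNotImplemented -> handled by Other.__rmatmul__

print(tensordot_like(ReturnsNotImplemented(), Other()) is NotImplemented)  # True
```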

keewis, Jan 15 '25 13:01

`sparse @ dask` (in xarray) raises a `TypeError` from sparse's `_dot`, I believe. If you're using xarray with opt-einsum, it selects `sparse.tensordot` as the method to use for `sparse @ dask` (i.e. `xr.dot(sparse, dask)`).

brendan-m-murphy, Jan 15 '25 13:01

Right, I just pushed a fix/release for `__matmul__` and `__rmatmul__`, but functions (as opposed to operators) should indeed raise `TypeError`, which is correct.

hameerabbasi, Jan 15 '25 14:01

Yeah, forwarding to `__rmatmul__` only works if we use `sparse @ dask`, whereas `xr.dot` uses something else. So we'll have to figure out how sparse can reject unknown array types while still allowing other array types to implement the operation.

`tensordot` can be dispatched through various NumPy protocols. For `__array_function__`, I believe deferring is possible by returning `NotImplemented`, but `__array_namespace__` explicitly does not intend to support interaction between multiple array types. This may mean that the calling code (in this context, probably xarray) has to cast the data to the appropriate type.
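A minimal sketch of the `__array_function__` (NEP 18) deferral described above, using hypothetical `SparseLike`, `DaskLike`, and `Stubborn` classes; this is not how sparse or Dask actually implement the protocol. NumPy calls each overriding argument type's `__array_function__` in turn until one returns something other than `NotImplemented`, and raises `TypeError` if all of them defer. Note that `types` only contains argument types that define `__array_function__`, so a real implementation would also need to validate `args`.

```python
import numpy as np


class SparseLike:
    """Hypothetical sparse-array stand-in that defers on unfamiliar types."""

    def __array_function__(self, func, types, args, kwargs):
        # Handle tensordot only if every overriding type is one we know.
        known = all(issubclass(t, SparseLike) for t in types)
        if func is not np.tensordot or not known:
            return NotImplemented  # let NumPy try the next override
        return "tensordot handled by SparseLike"


class DaskLike:
    """Hypothetical chunked-array stand-in that knows how to handle SparseLike."""

    def __array_function__(self, func, types, args, kwargs):
        known = all(issubclass(t, (SparseLike, DaskLike)) for t in types)
        if func is not np.tensordot or not known:
            return NotImplemented
        return "tensordot handled by DaskLike"


class Stubborn:
    """Hypothetical type that never implements anything."""

    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented


# SparseLike (left-most argument) is tried first and defers; DaskLike handles it.
print(np.tensordot(SparseLike(), DaskLike()))  # tensordot handled by DaskLike

# If every override defers, NumPy itself raises TypeError.
try:
    np.tensordot(SparseLike(), Stubborn())
except TypeError as exc:
    print("TypeError:", exc)
```

This matches the idea of letting sparse reject unknown array types without blocking them: the rejection is expressed as `NotImplemented`, and NumPy's dispatcher, not sparse, decides what happens next.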

keewis, Jan 15 '25 14:01