
`sum(..., skipna=True)` should dispatch to `sparse.nansum` for sparse arrays

peanutfun opened this issue 2 months ago • 9 comments

What is your issue?

xarray currently uses its own nanops.nansum when calling DataArray.sum(..., skipna=None), which relies on sum_where. This route is very inefficient for sparse arrays, especially (and ironically) when operating on a sparse array with fill_value=np.nan; see https://github.com/pydata/sparse/issues/908. Why doesn't xarray try to dispatch to a nansum implementation in the array's namespace, if one exists?

sparse offers its own nansum. Internally it also seems to use where, but it is much faster than xarray's nansum. I applied the following patch to duck_array_ops.py, which reduced the time for sums on a sparse array significantly:

--- duck_array_ops.py	2025-11-14 12:21:49
+++ duck_array_ops.py	2025-11-14 12:23:20
@@ -519,6 +519,15 @@
 
             nanname = "nan" + name
             func = getattr(nanops, nanname)
+
+            if "min_count" not in kwargs or kwargs["min_count"] is None:
+                try:
+                    kwargs.pop("min_count", None)
+                    xp = get_array_namespace(values)
+                    func = getattr(xp, name)
+                except AttributeError:
+                    pass
+
         else:
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)

Dispatching to sparse.nansum produces a factor 20+ speedup:

# Without patch
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim=['y', 'z'])"
1 loop, best of 5: 36.2 msec per loop

# With patch
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim=['y', 'z'])"
200 loops, best of 5: 1.37 msec per loop

peanutfun avatar Nov 14 '25 11:11 peanutfun

Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!

welcome[bot] avatar Nov 14 '25 11:11 welcome[bot]

The main problem with this is that the nan* reductions are not part of the array API (see data-apis/array-api#621), so we can't rely on them as far as using the array API is concerned.

If you look at the way this discussion has turned out, it appears that the nan* reductions are seen as a way of doing masked arrays, so the argument is that instead of supporting these it would be better to use a masked-array library like marray.

Is there a specific reason for choosing a fill value of nan?

keewis avatar Nov 14 '25 14:11 keewis

I support us attempting to detect whether xp.nan* exists, and caching the result of this lookup, because of the perf impact here and because we choose nan* reductions by default for float dtypes.

dcherian avatar Nov 14 '25 14:11 dcherian

I am not certain, but I believe at some point we tried to allow __array_function__ to overload np.nan* even if the array namespace is defined. I think that's what cubed does now? Until there's a clear decision on the array API discussion I think supporting things that way might be easier to maintain, especially since sparse already supports __array_function__.

keewis avatar Nov 14 '25 14:11 keewis

Yes, __array_function__ would be fine. Indeed I'm surprised we don't do that already? @peanutfun are you able to dig in here?
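For illustration, here is a minimal toy duck array that intercepts np.nansum through __array_function__; this is not sparse's actual implementation, just a sketch of the dispatch mechanism under discussion:

```python
import numpy as np


class NanAwareArray:
    # Toy duck array: np.nansum(x) is routed to __array_function__
    # instead of coercing x to a plain ndarray.
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.nansum:
            # Delegate to numpy on the underlying buffer; a library like
            # sparse would run its own optimized kernel here instead.
            return np.nansum(self.data, *args[1:], **kwargs)
        return NotImplemented


x = NanAwareArray([1.0, np.nan, 2.0])
print(np.nansum(x))  # dispatches through __array_function__ -> 3.0
```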

dcherian avatar Nov 14 '25 15:11 dcherian

I am not sure I can follow, because the details of __array_function__ and __array_ufunc__ elude me 😬

But here's what happens on the xarray side of things for sum():

  1. Depending on the value of skipna, the wrapper in duck_array_ops calls xarray.nansum or dispatches to xp.sum, where xp is the array namespace: https://github.com/pydata/xarray/blob/6a7a11b5fd45a443d9d27f2d669da0bdd726656a/xarray/core/duck_array_ops.py#L491
  2. xarray.nansum calls xarray.sum_where: https://github.com/pydata/xarray/blob/6a7a11b5fd45a443d9d27f2d669da0bdd726656a/xarray/computation/nanops.py#L97
  3. xarray.sum_where basically calls xp.where and xp.sum
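In pure numpy terms, steps 2-3 amount to something like this (a sketch; numpy stands in for the namespace xp, and np.isnan for xarray's null check):

```python
import numpy as np


def sum_where(data, axis=None):
    # Sketch of the where-based route: zero out the NaNs, then reduce
    # with the namespace's plain sum.
    mask = np.isnan(data)             # null mask
    filled = np.where(mask, 0, data)  # xp.where
    return np.sum(filled, axis=axis)  # xp.sum


print(sum_where(np.array([1.0, np.nan, 2.0])))  # 3.0
```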

If I understand correctly, to use __array_function__ of sparse, some code in xarray would still need to call xp.nansum on the array, right?

@keewis: I am using NaN as the fill value because I see it as a natural choice for sparse arrays in xarray. Xarray uses NaNs to identify missing values. In my application, I am working with geospatial datasets and odc.geo, which masks data with NaN values. Combining datasets after applying separate masks via fillna or combine_first should be relatively cheap when using NaN as a fill value, right? 🤔

peanutfun avatar Nov 17 '25 10:11 peanutfun

I guess we could try to dispatch nansum using __array_function__, and use sum_where as a fallback?
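Roughly like this, perhaps (a hypothetical sketch; `nansum_with_fallback` is a made-up name):

```python
import numpy as np


def nansum_with_fallback(values, **kwargs):
    # Let np.nansum dispatch via __array_function__ first; numpy raises
    # TypeError when every handler returns NotImplemented, in which case
    # we fall back to the where-based emulation.
    try:
        return np.nansum(values, **kwargs)
    except TypeError:
        return np.sum(np.where(np.isnan(values), 0, values), **kwargs)


print(nansum_with_fallback(np.array([1.0, np.nan, 2.0])))  # 3.0
```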

Combining datasets after applying separate masks via fillna or combine_first should be relatively cheap when using NaN as a fill value, right?

There are some subtleties here: https://github.com/pydata/sparse/issues/871

dcherian avatar Nov 17 '25 15:11 dcherian

The switch from using np.nansum to sum_where happened in https://github.com/pydata/xarray/pull/7067. https://github.com/pydata/xarray/blob/33ce95e30ecf8ce4b367401ea3b52b4b748e2f6f/xarray/core/duck_array_ops.py#L389-L396

From a quick glance the sparse implementation of nansum doesn't look that different to sum_where though? https://github.com/pydata/sparse/blob/aea60066b621afecebac08df734a5e503fc3e4b7/sparse/numba_backend/_coo/common.py#L719

Illviljan avatar Nov 17 '25 18:11 Illviljan

@Illviljan I tried to dig into this a bit more and noticed I made an embarrassing mistake: instead of dispatching to nanname, I dispatched to name in the above patch, meaning that I compared the timings for xarray's nansum and sparse.sum. When dispatching to sparse.nansum, the difference in timings is far less dramatic, but still noticeable (36 ms vs. 29 ms). I am not sure this still warrants a change in the xarray code base. See below for the timings and profiler output (for taking the sum 100 times) in both cases.

Without patch

$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim=['y', 'z'])"
1 loop, best of 5: 36.2 msec per loop
[profiler output screenshot]

With patch

$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim=['y', 'z'])"
1 loop, best of 5: 29 msec per loop
[profiler output screenshot]

Fixed patch

--- duck_array_ops.py	2025-11-14 12:21:49
+++ duck_array_ops.py	2025-11-14 12:23:20
@@ -519,6 +519,15 @@
 
             nanname = "nan" + name
             func = getattr(nanops, nanname)
+
+            if "min_count" not in kwargs or kwargs["min_count"] is None:
+                try:
+                    kwargs.pop("min_count", None)
+                    xp = get_array_namespace(values)
+                    func = getattr(xp, nanname)  # <-- Fixed
+                except AttributeError:
+                    pass
+
         else:
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)

peanutfun avatar Nov 18 '25 15:11 peanutfun