sum(min_count=1) raises an exception
The first sum() call below works; the second, with min_count=1, raises an exception:
import numpy as np
import xarray as xr
import cupy_xarray
xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum().compute()
xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1).compute()
<xarray.DataArray 'asarray-75d4a7ce4023e88c4c5563214cb235b4' ()> Size: 8B
array(3.)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 6
      3 import cupy_xarray
      5 xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum().compute()
----> 6 xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1).compute()

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/core/dataarray.py:1179, in DataArray.compute(self, **kwargs)
   1154 """Manually trigger loading of this array's data from disk or a
   1155 remote source into memory and return a new array.
   1156
   (...)
   1176 dask.compute
   1177 """
   1178 new = self.copy(deep=False)
-> 1179 return new.load(**kwargs)

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/core/dataarray.py:1147, in DataArray.load(self, **kwargs)
   1127 def load(self, **kwargs) -> Self:
   1128     """Manually trigger loading of this array's data from disk or a
   1129     remote source into memory and return this array.
   1130
   (...)
   1145     dask.compute
   1146     """
-> 1147 ds = self._to_temp_dataset().load(**kwargs)
   1148 new = self._from_temp_dataset(ds)
   1149 self._variable = new._variable

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/core/dataset.py:863, in Dataset.load(self, **kwargs)
    860 chunkmanager = get_chunked_array_type(*lazy_data.values())
    862 # evaluate all the chunked arrays simultaneously
--> 863 evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
    864     *lazy_data.values(), **kwargs
    865 )
    867 for k, data in zip(lazy_data, evaluated_data):
    868     self.variables[k].data = data

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/xarray/namedarray/daskmanager.py:86, in DaskManager.compute(self, *data, **kwargs)
     81 def compute(
     82     self, *data: Any, **kwargs: Any
     83 ) -> tuple[np.ndarray[Any, _DType_co], ...]:
     84     from dask.array import compute
---> 86 return compute(*data, **kwargs)

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/dask/base.py:662, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    659     postcomputes.append(x.__dask_postcompute__())
    661 with shorten_traceback():
--> 662     results = schedule(dsk, keys, **kwargs)
    664 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File cupy/_core/core.pyx:1717, in cupy._core.core._ndarray_base.__array_function__()

File ~/mambaforge/envs/cupy-seaice/lib/python3.12/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
    209 if fusion._is_fusing():
    210     return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)

File cupy/_core/_kernel.pyx:1286, in cupy._core._kernel.ufunc.__call__()

File cupy/_core/_kernel.pyx:159, in cupy._core._kernel._preprocess_args()

File cupy/_core/_kernel.pyx:145, in cupy._core._kernel._preprocess_arg()

TypeError: Unsupported type <class 'numpy.ndarray'>
Versions:
xr.__version__           # '2024.6.0'
np.__version__           # '1.26.4'
cupy_xarray.__version__  # '0.1.3+9.g7fc3df5'
The same thing happens with numpy 2.0.0.
Hi @yt87, thanks for the bug report with the minimal example. I can reproduce the same TypeError locally on my end too.
My initial impression is that this might require some fixes on the dask side. I've seen similar issues before, e.g. https://github.com/dask/dask/issues/9315, which suggest that some ufunc operations don't work with a CuPy backend yet. If I run the following without dask chunks, it works:
import cupy

ds = xr.DataArray([1, 2, cupy.nan]).as_cupy().sum(min_count=1)
print(ds)
# <xarray.DataArray ()> Size: 8B
# array(3.)
Do you need to do the sum(min_count=1) operation on dask chunks? If you put the .compute() before the .sum(), this works:
ds = xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().compute()
ds.sum(min_count=1)
Though that assumes your actual array isn't too large to fit in GPU memory. If it is too large, you might need to parallelize the sum yourself without dask as a workaround; a rough sketch of one way to do that is below.
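For example, something along these lines. This is only a sketch, assuming the data lives in a host (T, Y, X) numpy array and is reduced over the time axis; gpu_nansum_min_count, data, and block are illustrative names, not a real API:

import cupy as cp
import numpy as np

def gpu_nansum_min_count(data: np.ndarray, block: int = 64) -> cp.ndarray:
    """Blockwise nansum over axis 0 with min_count=1 semantics."""
    total = cp.zeros(data.shape[1:], dtype=cp.float64)
    valid = cp.zeros(data.shape[1:], dtype=cp.int64)
    for start in range(0, data.shape[0], block):
        chunk = cp.asarray(data[start:start + block])  # copy one block to the GPU
        total += cp.nansum(chunk, axis=0)
        valid += cp.sum(~cp.isnan(chunk), axis=0)
    # min_count=1: the result is NaN wherever no valid values contributed
    return cp.where(valid >= 1, total, cp.nan)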
It is np.nan that causes the error:
print(xr.DataArray([1, 2, 3]).chunk(dim_0=1).as_cupy().sum(min_count=1))
<xarray.DataArray 'asarray-60e02971486c2931a91b659a5bdc6e30' ()> Size: 8B
dask.array<sum-aggregate, shape=(), dtype=int64, chunksize=(), chunktype=cupy.ndarray>
print(xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1))
<xarray.DataArray 'asarray-75d4a7ce4023e88c4c5563214cb235b4' ()> Size: 8B
dask.array<where, shape=(), dtype=float64, chunksize=(), chunktype=numpy.ndarray>
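The chunktype shown in those reprs reflects the dask graph's meta; one way to inspect it directly, using dask's private ._meta attribute ("lazy" is just an illustrative name):

lazy = xr.DataArray([1, 2, np.nan]).chunk(dim_0=1).as_cupy().sum(min_count=1)
print(type(lazy.data._meta))  # <class 'numpy.ndarray'>, but it should be cupy.ndarray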
My use case: I have a large TYX array, ~12 GB. For some time values all the data is missing; in those cases I want the sum to return nan. When there is some data available, I want the actual value. Maybe an option is to drop the all-missing time frames beforehand, as sketched below.
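A minimal sketch of that dropping idea, assuming a DataArray da with dims ("time", "y", "x") that is still numpy- or dask-backed before moving it to the GPU (the names here are illustrative):

import numpy as np

# One bool per time step: does this frame contain any valid data?
keep = da.notnull().any(dim=("y", "x")).compute()
# Keep only the non-empty frames, then move the trimmed array to the GPU.
da_trimmed = da.isel(time=np.flatnonzero(keep)).as_cupy()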
This fix seems to work for me (in duck_array_ops.py, function as_shared_dtype):

# Avoid calling array_type("cupy") repeatedly in the any() check
array_type_cupy = array_type("cupy")
# GT fix: also treat dask arrays whose chunks are cupy.ndarray as cupy
import cupy as cp

if any(
    isinstance(x, array_type_cupy)
    or (is_duck_dask_array(x) and isinstance(x._meta, cp.ndarray))
    for x in scalars_or_arrays
):
    xp = cp
elif xp is None:
    xp = get_array_namespace(scalars_or_arrays)
What happens is that np.nan gets converted to an np.ndarray (see my previous message), which then causes the failure at compute time, when the cupy kernel expects cupy arrays. This is not the right fix, though: it makes xarray depend on cupy. There must be a better way.
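One possible dependency-free direction, purely as a sketch of the idea and not xarray's actual code (namespace_for is a made-up helper; _meta is dask's private attribute for the chunk type):

import importlib

def namespace_for(x):
    """Infer the array namespace from an array's type (or a dask array's
    chunk meta), so scalars like np.nan can be converted with the right
    backend without xarray ever importing cupy directly."""
    meta = getattr(x, "_meta", x)               # unwrap dask arrays to their chunk meta
    root = type(meta).__module__.split(".")[0]  # e.g. "numpy" or "cupy"
    return importlib.import_module(root)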
We do have to handle this in xarray. Can you open an issue there, please?