`as_shared_dtype` converts scalars to 0d `numpy` arrays if chunked `cupy` is involved
I tried to run `where` with chunked `cupy` arrays:
In [1]: import xarray as xr
...: import cupy
...: import dask.array as da
...:
...: arr = xr.DataArray(cupy.arange(4), dims="x")
...: mask = xr.DataArray(cupy.array([False, True, True, False]), dims="x")
this works:
In [2]: arr.where(mask)
Out[2]:
<xarray.DataArray (x: 4)>
array([nan, 1., 2., nan])
Dimensions without coordinates: x
this fails:
In [4]: arr.chunk().where(mask).compute()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 arr.chunk().where(mask).compute()
File ~/repos/xarray/xarray/core/dataarray.py:1095, in DataArray.compute(self, **kwargs)
1076 """Manually trigger loading of this array's data from disk or a
1077 remote source into memory and return a new array. The original is
1078 left unaltered.
(...)
1092 dask.compute
1093 """
1094 new = self.copy(deep=False)
-> 1095 return new.load(**kwargs)
File ~/repos/xarray/xarray/core/dataarray.py:1069, in DataArray.load(self, **kwargs)
1051 def load(self: T_DataArray, **kwargs) -> T_DataArray:
1052 """Manually trigger loading of this array's data from disk or a
1053 remote source into memory and return this array.
1054
(...)
1067 dask.compute
1068 """
-> 1069 ds = self._to_temp_dataset().load(**kwargs)
1070 new = self._from_temp_dataset(ds)
1071 self._variable = new._variable
File ~/repos/xarray/xarray/core/dataset.py:752, in Dataset.load(self, **kwargs)
749 import dask.array as da
751 # evaluate all the dask arrays simultaneously
--> 752 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
754 for k, data in zip(lazy_data, evaluated_data):
755 self.variables[k].data = data
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
597 keys.append(x.__dask_keys__())
598 postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
86 elif isinstance(pool, multiprocessing.pool.Pool):
87 pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
90 pool.submit,
91 pool._max_workers,
92 dsk,
93 keys,
94 cache=cache,
95 get_id=_thread_get_id,
96 pack_exception=pack_exception,
97 **kwargs,
98 )
100 # Cleanup pools associated to dead threads
101 with pools_lock:
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
509 _execute_task(task, data) # Re-execute locally
510 else:
--> 511 raise_exception(exc, tb)
512 res, worker_id = loads(res_info)
513 state["cache"][key] = res
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:319, in reraise(exc, tb)
317 if exc.__traceback__ is not tb:
318 raise exc.with_traceback(tb)
--> 319 raise exc
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
222 try:
223 task, data = loads(task_info)
--> 224 result = _execute_task(task, data)
225 id = get_id()
226 result = dumps((result, id))
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
988 if not len(args) == len(self.inkeys):
989 raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:149, in get(dsk, out, cache)
147 for key in toposort(dsk):
148 task = dsk[key]
--> 149 result = _execute_task(task, cache)
150 cache[key] = result
151 result = _execute_task(out, cache)
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File <__array_function__ internals>:180, in where(*args, **kwargs)
File cupy/_core/core.pyx:1723, in cupy._core.core._ndarray_base.__array_function__()
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
209 if fusion._is_fusing():
210 return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)
File cupy/_core/_kernel.pyx:1287, in cupy._core._kernel.ufunc.__call__()
File cupy/_core/_kernel.pyx:160, in cupy._core._kernel._preprocess_args()
File cupy/_core/_kernel.pyx:146, in cupy._core._kernel._preprocess_arg()
TypeError: Unsupported type <class 'numpy.ndarray'>
this works again:
In [7]: arr.chunk().where(mask.chunk(), cupy.array(cupy.nan)).compute()
Out[7]:
<xarray.DataArray (x: 4)>
array([nan, 1., 2., nan])
Dimensions without coordinates: x
And other methods like fillna show similar behavior.
I think the reason is that this check: https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/duck_array_ops.py#L195 is not sufficient to detect cupy beneath other layers of duckarrays (most commonly dask, pint, or both). In this specific case we could extend the condition to also match chunked cupy arrays (like `arr.cupy.is_cupy` does, but using `is_duck_dask_array`), but that would still break for other duckarray layers, or if dask is not involved, and we're also in the process of moving away from special-casing dask. So short of asking cupy to treat 0-d arrays like scalars, I'm not sure how to fix this.
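To illustrate, a minimal sketch of what goes wrong (assuming a CUDA-capable environment): the condition is essentially an `isinstance` check against `cupy.ndarray`, which the dask wrapper hides even though dask keeps the block type around in `_meta`:

import cupy
import xarray as xr

arr = xr.DataArray(cupy.arange(4), dims="x")
chunked = arr.chunk().data  # dask.array.Array whose blocks are cupy.ndarray

isinstance(arr.data, cupy.ndarray)  # True  -> scalars are cast to 0-d cupy arrays
isinstance(chunked, cupy.ndarray)   # False -> scalars are cast to 0-d numpy arrays
type(chunked._meta)                 # cupy.ndarray, but the check never looks here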
cc @jacobtomlinson
Ping @leofang in case you have thoughts?
Sorry that I missed the ping, Jacob, but I'd need more context before I can make any suggestions 😅 Is the question about why CuPy wouldn't return scalars?
The issue is that here: https://github.com/pydata/xarray/blob/d4db16699f30ad1dc3e6861601247abf4ac96567/xarray/core/duck_array_ops.py#L193-L206 we try to convert everything to the same dtype, casting numpy and python scalars to an array. The latter is important, because e.g. numpy.array_api.where only accepts arrays as input.
However, detecting cupy beneath (multiple) layers of duckarrays is not easy, which means that for example passing a pint(dask(cupy)) array together with scalars will currently cast the scalars to 0-d numpy arrays, while passing a cupy array instead will result in 0-d cupy arrays.
My naive suggestion was to treat np.int64(0) and np.array(0, dtype="int64") the same, where at the moment the latter would fail for the same reason as np.array([0], dtype="int64").
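For what it's worth, a minimal sketch of that asymmetry as I understand it (assuming a CUDA-capable environment): the `TypeError` in the traceback above comes from cupy rejecting a `numpy.ndarray`, while numpy scalars are accepted:

import numpy as np
import cupy

cond = cupy.array([False, True, True, False])
data = cupy.arange(4)

cupy.where(cond, data, np.float64(np.nan))  # numpy scalar: accepted
cupy.where(cond, data, np.array(np.nan))    # 0-d numpy array: TypeError: Unsupported type <class 'numpy.ndarray'>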
Thanks, Justus, for expanding on this. It sounds to me like the question is "how do we cast dtypes when multiple array libraries are participating in the same computation?", and I am not sure I am knowledgeable enough to make any comment.
From the array API point of view, long long ago we decided that this is UB (undefined behavior), meaning it's completely up to each library to decide what to do. You can raise or come up with a special rule that you can make sense of.
It sounds like Xarray has some machinery to deal with this situation, but you'd prefer not to keep special-casing a particular array library? Am I understanding it right?
there are two things that happen in as_shared_dtype (which may not be good design; we should probably consider splitting it into as_shared_dtype and as_compatible_arrays or something): first, we cast everything to an array, then we decide on a common dtype and cast everything to that.
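For illustration, a rough sketch of what that split might look like, written with plain numpy and hypothetical signatures (the real function lives in xarray.core.duck_array_ops and handles more cases):

import numpy as np

def as_compatible_arrays(scalars_or_arrays, xp=np):
    # hypothetical first step: cast python / numpy scalars to 0-d arrays of the
    # target namespace so that functions like xp.where accept them
    return [
        x
        if hasattr(x, "__array_function__") or hasattr(x, "__array_namespace__")
        else xp.asarray(x)
        for x in scalars_or_arrays
    ]

def as_shared_dtype(scalars_or_arrays, xp=np):
    # hypothetical second step: pick a common dtype and cast everything to it
    arrays = as_compatible_arrays(scalars_or_arrays, xp=xp)
    dtype = np.result_type(*arrays)
    return [x.astype(dtype) for x in arrays]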
The dtype-casting step could easily be done using numpy scalars, which as far as I can tell are supported by most array libraries, including cupy. However, the reason we need to cast to arrays is that the array API (i.e. __array_namespace__) does not allow scalars of any kind in functions like np.array_api.where (this is important for libraries that don't implement __array_ufunc__ / __array_function__). To clarify, what we're trying to support is something like
import numpy.array_api as np
np.where(cond, cupy_array, python_scalar)
which (intentionally?) does not work.
At the moment, as_shared_dtype (or, really, the hypothetical as_compatible_arrays) correctly casts python_scalar to a 0-d cupy.array for the example above, but if we were to replace cupy_array with chunked_cupy_array or chunked_cupy_array_with_units, the special casing for cupy stops working and scalars are cast to 0-d numpy.array instead. Conceptually, I tend to think of 0-d arrays as equivalent to scalars, hence the suggestion to have cupy treat numpy scalars and 0-d numpy.array the same way (I don't follow the array API closely enough to know whether that was already discussed and rejected).
So really, my question is: how do we support python scalars for libraries that only implement __array_namespace__, given that dropping that support would be a major breaking change?
Of course, I would prefer removing the special casing for specific libraries, but I wouldn't be opposed to keeping the existing one. I guess as a short-term fix we could just pull _meta out of duck dask arrays and determine the common array type for that (the downside is that we'd add another special case for dask, which in another PR we're actually trying to remove).
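A rough sketch of that short-term idea (the helper name is hypothetical; it assumes is_duck_dask_array is in scope, as it already is in duck_array_ops):

def _leaf_array_type(x):
    # duck dask arrays carry an empty array of their block type in `_meta`,
    # which reveals cupy underneath the dask layer
    if is_duck_dask_array(x):
        x = x._meta
    return type(x)

# _leaf_array_type(arr.chunk().data) would give cupy.ndarray, so the cupy branch
# in as_shared_dtype could key off this instead of a plain isinstance check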
As a long-term fix I guess we'd need to revive the stalled nested duck array discussion.
> So really, my question is: how do we support python scalars for libraries that only implement __array_namespace__, given that dropping that support would be a major breaking change?
I was considering this question for SciPy (xref scipy#18286) this week, and I think I'm happy with this strategy:
- Cast all "array-like" inputs like Python scalars, lists/sequences, and generators, to numpy.ndarray.
- Require "same array type" input, forbid mixing numpy-cupy, numpy-pytorch, cupy-pytorch, etc. - this will raise an exception.
- As a result, cupy-pyscalar and pytorch-pyscalar will also raise an exception (a rough sketch of this rule follows the list).
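As a hedged sketch of that rule (a hypothetical helper, not SciPy's actual code):

def _coerce_or_raise(args, xp):
    # array-likes (python scalars, lists, ...) are coerced to the target library;
    # arrays from a different library raise instead of being silently converted
    target_array_type = type(xp.asarray(0))
    coerced = []
    for a in args:
        if isinstance(a, (bool, int, float, complex, list, tuple)):
            coerced.append(xp.asarray(a))
        elif isinstance(a, target_array_type):
            coerced.append(a)
        else:
            raise TypeError(f"cannot mix {type(a).__name__} with {xp.__name__} arrays")
    return coerced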
What that results in is an API that's backwards-compatible for numpy and array-like usage, and much stricter when using other array libraries. That strictness to me is a good thing, because:
- that's what CuPy, PyTorch & co themselves do, and it works well there
- it avoids the complexity of arbitrary mixing, which leads to questions like the one raised in this issue.
- in case you do need to use a scalar from within a function inside your own library, just convert it explicitly to the desired array type with `xp.asarray(a_scalar)`, giving you a 0-D array of the correct type (add `dtype=x.dtype` to make sure dtypes match, if that matters) - see the sketch below
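A minimal sketch of that last point, using numpy.array_api as the namespace (as in the earlier example):

import numpy.array_api as xp

x = xp.asarray([1.0, 2.0, 3.0])
fill = xp.asarray(0.5, dtype=x.dtype)  # 0-D array of the right type and dtype
xp.where(x > 1.5, x, fill)             # works; passing the bare 0.5 would raise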
So, after thinking about this for (quite) some time, it appears that one way or another we need to figure out the appropriate base array type of the nested array (regardless of whether or not we disallow passing python scalars to the xarray API... though since it is a breaking change I don't think we will do that).
I've come up with a (recursive) way of extracting the nesting structure in keewis/nested-duck-arrays, which we should be able to use to determine the leaf array type and keep the current hack until we figure out how to resolve the issue without it.
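For context, a hedged sketch of the general idea (not the actual nested_duck_arrays API): peel off wrapper layers by the attributes the common duck arrays expose:

def iter_duck_layers(x):
    # e.g. pint(dask(cupy)) -> pint.Quantity, dask.array.Array, cupy.ndarray
    yield type(x)
    if hasattr(x, "magnitude"):  # pint.Quantity wraps its data as .magnitude
        yield from iter_duck_layers(x.magnitude)
    elif hasattr(x, "_meta"):    # dask.array.Array exposes its block type via ._meta
        yield from iter_duck_layers(x._meta)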
Would this be an acceptable, if temporary, fix for #9195? Modified code in as_shared_dtype:
array_type_cupy = array_type("cupy")

# temporary fix
import nested_duck_arrays.dask

def _maybe_cupy(seq):
    return any(
        isinstance(x, array_type_cupy)
        or (is_duck_dask_array(x) and x.__duck_arrays__()[-1].__module__ == 'cupy')
        for x in seq
    )

# if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
if _maybe_cupy(scalars_or_arrays):
# end of fix
    import cupy as cp
I'd go with something like
import nested_duck_arrays.dask
import nested_duck_arrays

...

if any(nested_duck_arrays.first_layer(x) is array_type_cupy for x in scalars_or_arrays):
    import cupy as cp
and add nested_duck_arrays.first_layer (with maybe a better name?), which would have a fallback of returning a 1-tuple containing the type of x in case x is not a duck array (I'd be happy to relatively quickly release that to PyPI / conda-forge).
We'll need to think about what to do if nested_duck_arrays is not installed, though... something like this, maybe?
try:
    from nested_duck_arrays import first_layer
except ImportError:
    def first_layer(x):
        return type(x)
Also, we'll probably want to push the contents of nested_duck_arrays.dask to dask.array.