Fix upcasting with python builtin numbers and numpy 2
See #8402 for more discussion. The bottom line is that numpy 2 changes the rules for casting between two inputs. Because of this, and because xarray prefers to promote python scalars to 0d arrays (scalar arrays), xarray objects are being upcast to higher data types when they previously weren't.
I'm mainly opening this PR for further and more detailed discussion.
CC @dcherian
- [ ] Closes #8402
- [ ] Tests added
- [ ] User visible changes (including notable bug fixes) are documented in whats-new.rst
- [ ] New functions/methods are listed in api.rst
Ugh my local clone was so old it was pointing to master. One sec...
Ok, so the failing test is the array API version (https://github.com/data-apis/array-api-compat), which expects both the x and y inputs of the where function to have a .dtype. Since this PR skips the scalar->array conversion, those objects won't have a .dtype. I'm not sure what the rules are for scalar inputs under the strict array API.
Looks like the array api strictly wants arrays: https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html
Related but I don't fully understand it: https://github.com/data-apis/array-api-compat/issues/85
I guess it depends on how you interpret the array API standard then. I can file an issue if needed. To me, depending on how you read the standard, it means either:
1. This test is flawed, as it tests scalar inputs when the array API specifically defines Array inputs.
2. The Array API package is flawed because it assumes and requires Array inputs when the standard allows for scalar inputs (I don't think this is true if I'm understanding the description).
The other point is that maybe numpy compatibility is more important until numpy more formally conforms to the array API standard (see the first note on https://data-apis.org/array-api/latest/API_specification/array_object.html#api-specification-array-object--page-root). But also type promotion seems wishy-washy and not super strict: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
I propose (because it works best for me and matches numpy's behavior) that I update the test to have a numpy case only, and add a new test function with numpy and array API cases that pass array inputs to .where instead of scalars.
I lean towards (1).
I looked at this for a while, and we'll need major changes around handling array API dtype objects to do this properly.
cc @keewis
> we'll need major changes around handling array API dtype objects to do this properly.
I think the change could be limited to xarray.core.duck_array_ops.as_shared_dtype. According to the Array API section on mixing scalars and arrays, we should use the dtype of the array (though it only looks at scalar + 1 array, so we'd need to extend that).
However, what we currently do is cast all scalars to arrays using asarray, which means python scalars get numpy's default dtypes (e.g. float64 for python floats, and a platform-dependent default for python ints).
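As a quick numpy illustration of why that matters (a sketch, assuming only numpy; it sticks to dtype-level promotion because how a 0-d array combines with another array is precisely what NEP 50 changed between numpy 1 and 2):

```python
import numpy as np

# A python float converted with asarray is committed to the default float dtype.
scalar_as_array = np.asarray(1.2)
print(scalar_as_array.dtype)  # float64

# Once it is a strongly typed float64 value, promoting it with float32 can only
# give float64, which is what numpy 2 does for 0-d arrays as well.
print(np.result_type(np.float32, scalar_as_array.dtype))  # float64

# Passing the python scalar itself keeps it weak, so the array dtype wins.
print(np.result_type(np.dtype("float32"), 1.2))  # float32
```

Keeping the python scalar weak lets the array's dtype win; converting it first turns it into a `float64` value that numpy 2 promotes like any other array.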
As an algorithm, maybe this could work:
- separate the input into python scalars and arrays / scalars with dtype
- determine `result_type` using just the arrays / scalars with dtype
- check that all python scalars are compatible with the result (otherwise we might have to return `object`?)
- cast all input to arrays with the dtype
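For illustration, here is a minimal numpy-only sketch of those four steps. `as_shared_dtype_sketch` is a hypothetical name, it ignores cupy, dask and extension arrays, and the compatibility check in step 3 is a deliberate simplification (it asks `np.can_cast` with `casting="same_kind"` whether the scalar's default dtype has the same kind as the result) rather than a full set of rules:

```python
import numpy as np


def as_shared_dtype_sketch(values):
    """Hypothetical numpy-only version of the four steps above."""
    # 1. split into python scalars and objects that already carry a dtype
    strong = [v for v in values if hasattr(v, "dtype")]
    weak = [v for v in values if not hasattr(v, "dtype")]

    if strong:
        # 2. result type from the explicitly dtyped inputs only
        dtype = np.result_type(*strong)
        # 3. simplified compatibility check: if a scalar's default dtype is of a
        #    different kind (e.g. a float next to an int array), promote with it
        for value in weak:
            default = np.asarray(value).dtype
            if not np.can_cast(default, dtype, casting="same_kind"):
                dtype = np.result_type(dtype, default)
    else:
        # only python scalars: nothing to anchor on, so use the default dtypes
        dtype = np.result_type(*[np.asarray(v) for v in weak])

    # 4. cast everything to arrays with the shared dtype, preserving order
    return [np.asarray(v, dtype=dtype) for v in values]
```

With an `int8` array, a python `1` keeps `int8`, while `1.2` fails the same-kind check and pushes everything to `float64`.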
> According to the Array API section on mixing scalars and arrays, we should use the dtype of the array (though it only looks at scalar + 1 array, so we'd need to extend that).
Do you know if this is in line with numpy 2 dtype casting behavior?
The main numpy namespace is supposed to be Array API compatible, so it should? I don't know for certain, though.
> check that all python scalars are compatible with the result (otherwise might have to return object?)
How do we check this?
Here's what I have locally which seems to pass:
Subject: [PATCH] Cast scalars as arrays with result type of only arrays
---
Index: xarray/core/duck_array_ops.py
===================================================================
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
--- a/xarray/core/duck_array_ops.py (revision e27f572585a6386729a5523c1f9082c72fa8d178)
+++ b/xarray/core/duck_array_ops.py (date 1713816523554)
@@ -239,20 +239,30 @@
         import cupy as cp
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
+        # Pass arrays directly instead of dtypes to result_type so scalars
+        # get handled properly.
+        # Note that result_type() safely gets the dtype from dask arrays without
+        # evaluating them.
+        out_type = dtypes.result_type(*arrays)
     else:
-        arrays = [
-            # https://github.com/pydata/xarray/issues/8402
-            # https://github.com/pydata/xarray/issues/7721
-            x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
-            for x in scalars_or_arrays
-        ]
-    # Pass arrays directly instead of dtypes to result_type so scalars
-    # get handled properly.
-    # Note that result_type() safely gets the dtype from dask arrays without
-    # evaluating them.
-    out_type = dtypes.result_type(*arrays)
+        # arrays = [
+        # # https://github.com/pydata/xarray/issues/8402
+        # # https://github.com/pydata/xarray/issues/7721
+        # x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
+        # for x in scalars_or_arrays
+        # ]
+        objs_with_dtype = [obj for obj in scalars_or_arrays if hasattr(obj, "dtype")]
+        if objs_with_dtype:
+            # Pass arrays directly instead of dtypes to result_type so scalars
+            # get handled properly.
+            # Note that result_type() safely gets the dtype from dask arrays without
+            # evaluating them.
+            out_type = dtypes.result_type(*objs_with_dtype)
+        else:
+            out_type = dtypes.result_type(*scalars_or_arrays)
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     return [
-        astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays
+        astype(x, out_type, copy=False) for x in arrays
     ]
I just threw it together to see if it would work. I'm not sure it is accurate, but the fact that it is almost exactly like the existing solution, with the only difference being the `out_type =` changes, makes me feel this is going in a good direction.
Note that I had to add the `if objs_with_dtype:` branch because the test passes two python scalars, so there are no arrays to determine the result type from.
> How do we check this?
Not sure... but there are only so many builtin types that can be involved without requiring object dtype, so we could just enumerate all of them? As far as I can tell, that would be: bool, int, float, str, datetime/date, and timedelta
> check that all python scalars are compatible with the result (otherwise might have to return object?)
>
> How do we check this?
@keewis Do you have a test that I can add to verify any fix I attempt for this? What do you mean by a python scalar being compatible with the result?
well, for example, what should happen for this:
a = xr.DataArray(np.array([1, 2, 3], dtype="int8"), dims="x")
xr.where(a % 2 == 1, a, 1.2)
According to the algorithm above, we have one array of dtype int8, so we'd have to check whether 1.2 (a float) is compatible with int8. It is not, so we should promote everything to float (the default would be float64, which might be a bit weird).
Something similar:
a = xr.DataArray(np.array(["2019-01-01", "2020-01-01"], dtype="datetime64[ns]"), dims="x")
xr.where(a.x % 2 == 1, a, datetime.datetime(2019, 6, 30))
in that case, the check should succeed, because we can convert a builtin datetime object to datetime64[ns].
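Both cases can be checked directly in numpy. This is a sketch assuming only numpy; it shows the underlying conversions, not how xarray should make the decision:

```python
import datetime

import numpy as np

# Case 1: 1.2 next to an int8 array. Forcing it into int8 truncates the value,
# so the python float is not compatible with int8 and the result must change kind.
as_int8 = np.asarray(1.2).astype("int8")
print(as_int8)  # 1

# Case 2: a builtin datetime next to a datetime64[ns] array. numpy converts it
# directly, so the array's dtype can be kept.
as_datetime = np.asarray(datetime.datetime(2019, 6, 30), dtype="datetime64[ns]")
print(as_datetime.dtype)  # datetime64[ns]
```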
I committed my (what I consider ugly) implementation of your original approach @keewis. I'm still not sure I understand how to approach the scalar compatibility so if someone has some ideas then please make some suggestion comments or commits directly if you have the permissions.
this might be cleaner:
def asarray(data, xp=np, dtype=None):
    return data if is_duck_array(data) else xp.asarray(data, dtype=dtype)


def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast arrays to a shared dtype using xarray's type promotion rules."""
    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
        # as soon as extension arrays are involved we only use this:
        extension_array_types = [
            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
        ]
        if len(extension_array_types) == len(scalars_or_arrays) and all(
            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
        ):
            return scalars_or_arrays
        raise ValueError(
            f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
        )
    # parenthesize the walrus so the assignment happens before `any` uses it
    if (array_type_cupy := array_type("cupy")) and any(
        isinstance(x, array_type_cupy) for x in scalars_or_arrays
    ):
        import cupy as cp
        xp_ = cp
    else:
        xp_ = xp
    # split into python scalars and arrays / numpy scalars (i.e. into weakly and strongly dtyped)
    with_dtype = {}
    python_scalars = {}
    for index, elem in enumerate(scalars_or_arrays):
        append_to = with_dtype if hasattr(elem, "dtype") else python_scalars
        append_to[index] = elem
    if with_dtype:
        to_convert = with_dtype
    else:
        # can't avoid using the default dtypes if we only get weak dtypes
        to_convert = python_scalars
        python_scalars = {}
    arrays = {index: asarray(x, xp=xp_) for index, x in to_convert.items()}
    common_dtype = dtypes.result_type(*arrays.values())
    # TODO(keewis): check that all python scalars are compatible. If not, change the dtype or raise.
    # cast arrays
    cast = {index: astype(x, dtype=common_dtype, copy=False) for index, x in arrays.items()}
    # convert python scalars to arrays with a specific dtype
    converted = {index: asarray(x, xp=xp_, dtype=common_dtype) for index, x in python_scalars.items()}
    # merge both
    combined = cast | converted
    return [x for _, x in sorted(combined.items(), key=lambda x: x[0])]
This is still missing the dtype fallbacks, though.
I see now why the dtype fallbacks for scalars are tricky... we basically need to enumerate the casting rules, and decide when to return a different dtype (like object). numpy has can_cast with the option to choose the strictness (so we could use "same_kind") and it accepts python scalar types, while the Array API does not allow that choice, and we also can't pass in python scalar types.
To start, here are the rules from the Array API:
- `complex` dtypes are compatible with `int`, `float`, or `complex`
- `float` dtypes are compatible with any `int` or `float`
- `int` dtypes are compatible with `int` (but beware: python uses BigInt, so the value might exceed the maximum of the dtype)
- the `bool` dtype is only compatible with `bool`
From numpy, we also have these (numpy casting is even more relaxed than this, but that behavior may also cause some confusing issues):
- `bool` can be cast to `int`, so it is compatible with anything `int` is compatible with
- `str` dtypes are only compatible with `str`. Anything else, like formatting and casting to other types, has to be done explicitly before calling `as_shared_dtype`.
- `datetime` dtypes (any precision) are compatible with `datetime.datetime`, `datetime.date`, and `pd.Timestamp`
- `timedelta` dtypes (any precision) are compatible with `datetime.timedelta` and `pd.Timedelta`. Casting to `int` is possible, but has to be done explicitly (i.e. we can ignore it here)
- anything else results in an `object` dtype
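The two rule lists above could be encoded as a lookup from numpy's dtype `kind` codes to the python types each kind accepts. This is a minimal sketch, not the implementation that was eventually merged: `is_compatible` and `_COMPATIBLE` are hypothetical names, it only handles numpy dtypes, and it relies on `pd.Timestamp` / `pd.Timedelta` subclassing `datetime.datetime` / `datetime.timedelta` so pandas does not need to be imported:

```python
import datetime

import numpy as np

# Hypothetical encoding of the rules: numpy dtype kind -> python types a weakly
# dtyped scalar may have to be compatible with that kind. bool appears under the
# numeric kinds because bool can be cast to int.
_COMPATIBLE = {
    "c": (bool, int, float, complex),
    "f": (bool, int, float),
    "i": (bool, int),
    "u": (bool, int),
    "b": (bool,),
    "U": (str,),
    "M": (datetime.datetime, datetime.date),
    "m": (datetime.timedelta,),
}


def is_compatible(value, dtype):
    """Return True if the python scalar `value` can adopt `dtype` (sketch)."""
    dtype = np.dtype(dtype)
    if dtype.kind == "O":
        # object arrays can hold anything
        return True
    if not isinstance(value, _COMPATIBLE.get(dtype.kind, ())):
        # anything else would have to fall back to object dtype
        return False
    if dtype.kind in "iu" and not isinstance(value, bool):
        # python ints are unbounded, so also check the dtype's range
        info = np.iinfo(dtype)
        return info.min <= value <= info.max
    return True
```

The range check covers the BigInt caveat: `300` is an `int` but does not fit into `int8`.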
Edit: it appears NEP 50 describes the changes in detail. I didn't see that before writing both the list above and implementing the changes, so I might have to change both.
here's my shot at the ~scalar dtype verification~ (the final implementation we settled on in the end is much better). I'm pretty sure it can be cleaned up further (and we need more tests), but it does fix all the casting issues. Edit: note that this depends on the Array API fixes for numpy>=2.
What I don't like is that we're essentially hard-coding the dtype casting hierarchy, but I couldn't figure out a way to make it work without that.
FYI to everyone watching this, I'm going to be switching to a heavier paternity leave than I was already starting this week. I think someone else should take this PR over as I don't think I'll have time to finish it in time for the numpy 2 final release.
In an ideal world, I think this would be written something like:
def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)
    dtype = xp.result_type(*scalars_or_arrays)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)
The main issues stopping this:
1. cupy, pandas and old versions of numpy don't support the array API
2. `xp.result_type` only supports arrays, not Python scalars
The first issue can be solved with compatibility code. I will raise the second issue on the array API tracker.
While issue 2 should eventually be resolved by an addition to the Array API, that won't help us resolve the dtype casting before the release of numpy=2.0.
As far as I understand it, to put the above code sample (with slight modifications due to issue 1) in as_shared_dtype, we'd have to add compatibility code to dtypes.result_type, at least until we can require a version of the Array API that allows us to forward most of it to xp.result_type.
To do so, we'd have to find a way to split the input of dtypes.result_type into weakly dtyped and explicitly dtyped / dtypes (since dtypes.result_type and xp.result_type accept arrays / explicitly dtyped scalars or dtype objects). Then we can forward the latter to xp.result_type, and figure out what to do with the weakly dtyped data in an additional step.
However, while with numpy we can simply use isinstance(x, np.dtype) to find dtypes, this won't help us with other Array API-implementing libraries as the dtypes are generally opaque objects, and we also don't want to lose the ability to use where on dtype=object arrays. In other words, I can't find a way to separate weakly dtyped data from dtypes and explicitly dtyped data.
If there truly is no way to find dtypes in a general way, we'll have to do the split in as_shared_dtype, where we can guarantee that we don't get dtype objects:
def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)
    explicitly_dtyped, weakly_dtyped = dtypes.extract_explicitly_dtyped(scalars_or_arrays)
    common_dtype = dtypes.result_type(*explicitly_dtyped)
    dtype = dtypes.adapt_common_dtype(common_dtype, weakly_dtyped)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)
Another option would be to pass the weakly dtyped data as a keyword argument to dtypes.result_type, which in the future would allow us to pass both to xp.result_type if we know that a specific library supports python scalars.
Edit: actually, I'd probably go for the latter.
In the most recent commits I've added a way to check if a dtype is compatible with the scalars, using the second option from https://github.com/pydata/xarray/pull/8946#issuecomment-2125901039: split into weakly / strongly dtyped in as_shared_dtype, then pass both to result_type (but as separate arguments).
This appears to resolve the issues we had with dtype casting and numpy>=2, but there are a few other issues that pop up. For example, cupy doesn't have cupy.astype, and the failing pint tests seem to hint at an issue in my algorithm and in the existing non-pint test coverage (not sure, though).
I'll try to look into the above and into adding additional tests over the weekend.
In the meantime, I'd appreciate reviews for the general idea (cc in particular @shoyer, but also @dcherian).
looks like the most recent commits fixed the remaining failing tests (turns out I forgot to apply our custom dtype casting rules when adjusting the dtype to fit the scalars), so all that's left is to fix mypy and write tests.
Hey @keewis, thanks for continuing to dive into this!
Given that there seems to be a consensus to add support for weak/strong dtypes to the Array API's result_type in the future, let's try to build this code around the assumption that that will work in the future.
I think this could look something like the following:
def _future_array_api_result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp,
) -> np.dtype:
    ...


def result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp=None,
) -> np.dtype:
    from xarray.core.duck_array_ops import get_array_namespace

    if xp is None:
        xp = get_array_namespace(arrays_and_dtypes)
    types = {xp.result_type(t) for t in arrays_and_dtypes}
    if all(isinstance(t, np.dtype) for t in types):  # NOTE: slightly more conservative than the existing code
        # only check if there's numpy dtypes – the array API does not
        # define the types we're checking for
        for left, right in PROMOTE_TO_OBJECT:
            if any(np.issubdtype(t, left) for t in types) and any(
                np.issubdtype(t, right) for t in types
            ):
                return np.dtype(object)
    if xp is np:
        return np.result_type(*arrays_and_dtypes)  # fast path
    # TODO: replace with xp.result_type when the array API always supports weak dtypes:
    # https://github.com/data-apis/array-api/issues/805
    return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
that would be ideal, yes. However, I think there are a couple of issues that prevent us from implementing it the way you suggested:
1. `xp.result_type` returns the default dtype for a dtype kind for python scalars, and we can make use of that to only apply the custom casting rules once. However, `np.result_type("a")` returns `dtype('S')` (i.e. `np.bytes_`), and our custom `inf`, `ninf` and `nan` objects will raise. As a consequence, we need custom translation rules for scalars.
2. As far as I can tell, implementing the fallback version of `result_type` (`_future_array_api_result_type`) requires being able to separate weakly dtyped values from strongly dtyped values and dtype objects. However, since the Array API defines dtype objects as opaque objects and doesn't give any guarantees about them, I can't seem to figure out a way to actually do that.
3. We kinda support `xr.where(False, 0, 1)` right now (even though that feels kinda weird to have), which is going to be undefined behavior if I read the Array API discussion correctly. If we want to continue supporting that, we have to be able to use the default dtypes in that case. As far as I can tell, that also requires 2.
Thinking a bit more about 2, maybe instead of `if xp is np` we can use `if any(isinstance(t, np.dtype) for t in arrays_and_dtypes)`. That way, libraries that make use of numpy dtypes (datetime, timedelta, string, object) like cupy or sparse can use their version of `result_type`, which should already support these.
All the considerations above can be satisfied by
def _future_array_api_result_type(*arrays_and_dtypes, weakly_dtyped, xp):
    # the weakly dtyped values are handled below, so don't pass them to xp here
    dtype = xp.result_type(*arrays_and_dtypes)
    if not weakly_dtyped or is_object(dtype):
        return dtype
    possible_dtypes = {
        complex: "complex64",
        float: "float32",
        int: "int8",
        bool: "bool",
        str: "str",
    }
    dtypes = [possible_dtypes.get(type(x), "object") for x in weakly_dtyped]
    common_dtype = xp.result_type(dtype, *dtypes)
    return common_dtype
def determine_types(t, xp):
    if isinstance(t, str):
        return np.dtype("U")
    elif isinstance(t, (AlwaysGreaterThan, AlwaysLessThan, utils.ReprObject)):
        return object
    else:
        return xp.result_type(t)


def result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    weakly_dtyped=None,
    xp=None,
) -> np.dtype:
    from xarray.core.duck_array_ops import asarray, get_array_namespace

    if xp is None:
        xp = get_array_namespace(arrays_and_dtypes)
    if weakly_dtyped is None:
        weakly_dtyped = []
    if not arrays_and_dtypes:
        # no explicit dtypes, so we simply convert to 0-d arrays using default dtypes
        arrays_and_dtypes = [asarray(x, xp=xp) for x in weakly_dtyped]  # type: ignore
        weakly_dtyped = []
    types = {determine_types(t, xp=xp) for t in [*arrays_and_dtypes, *weakly_dtyped]}
    if any(isinstance(t, np.dtype) for t in types):
        # only check if there's numpy dtypes – the array API does not
        # define the types we're checking for
        for left, right in PROMOTE_TO_OBJECT:
            if any(np.issubdtype(t, left) for t in types) and any(
                np.issubdtype(t, right) for t in types
            ):
                return np.dtype(object)
    if xp is np or any(isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes):
        return xp.result_type(*arrays_and_dtypes, *weakly_dtyped)
    return _future_array_api_result_type(*arrays_and_dtypes, weakly_dtyped=weakly_dtyped, xp=xp)
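The core trick in `_future_array_api_result_type`, standing in for each weakly dtyped scalar with the smallest dtype of its kind, can be tried out with plain numpy. `promote_with_weak` and `WEAK_STAND_INS` are hypothetical names; the sketch only mirrors the numeric entries of `possible_dtypes`, and numpy's promotion table (not the Array API's) decides the result:

```python
import numpy as np

# Smallest dtype of each kind, standing in for a weakly dtyped python scalar
# (the numeric entries of the `possible_dtypes` mapping above).
WEAK_STAND_INS = {
    bool: np.dtype("bool"),
    int: np.dtype("int8"),
    float: np.dtype("float32"),
    complex: np.dtype("complex64"),
}


def promote_with_weak(strong_dtype, weak_values):
    """Combine an explicit dtype with python scalars via their stand-ins."""
    stand_ins = [WEAK_STAND_INS[type(value)] for value in weak_values]
    return np.result_type(np.dtype(strong_dtype), *stand_ins)
```

Because each stand-in is the smallest dtype of its kind, it mostly changes the kind of the result rather than its size, but numpy's table still applies: a `uint8` dtype next to a python int ends up as `int16`.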
That still requires passing the weakly dtyped values separately into result_type, which as I mentioned is because I couldn't find a way to implement 2 and 3 without this.
Edit: Additionally, we have to explicitly cast the other parameter of duck_array_ops.fillna to an array or modify some of the tests where we pass list / range objects.
Edit2: might have to remove the custom inf / ninf / na objects, though
Edit3: another issue: np.result_type will try to consider strings / bytes as dtype names, not as scalars.
Edit4: there's a lot of other edge cases. We'll have to try and sort this out at a later time
as a summary of the discussion we just had, and so I don't forget in the time until I implement this (cc @shoyer):
- splitting out weakly-dtyped values from dtypes / strongly-dtyped values should happen in our fallback version of the Array API's `result_type` (the aim being to be a drop-in replacement for a version of the Array API that supports python scalars in `result_type`)
- we consider only the numeric types (i.e. `bool`, `int`, `float` and `complex`) and `str` / `bytes`, allowing us to separate dtype objects from scalars
- dtype name strings will not be passed around from within `xarray`'s code, and so will never reach our version of `result_type` (though whether that's something that already happens is something to be verified)
- datetime / timedelta objects are currently not supported, but may be in the future (after finally releasing the numpy 2-compatible version of `xarray`)
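Restricting the weak types to a closed set of builtins means the split can be done with plain `isinstance` checks, since numpy dtype objects and arrays are not instances of these builtins. A minimal sketch, assuming numpy; `split_weakly_dtyped` is a hypothetical name and the code that was eventually written may differ:

```python
import numpy as np

# Only these builtin types are treated as weakly dtyped. Per the summary, dtype
# name strings are assumed never to reach this point, so any str is a scalar.
WEAKLY_DTYPED_TYPES = (bool, int, float, complex, str, bytes)


def split_weakly_dtyped(arrays_and_dtypes):
    """Split inputs into (strongly dtyped, weakly dtyped), keeping their order."""
    strong, weak = [], []
    for obj in arrays_and_dtypes:
        # numpy scalars such as np.float64 or np.str_ subclass the builtins but
        # carry an explicit dtype, so they count as strongly dtyped
        if isinstance(obj, WEAKLY_DTYPED_TYPES) and not isinstance(obj, np.generic):
            weak.append(obj)
        else:
            strong.append(obj)
    return strong, weak
```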
I decided to do this now rather than later. Good news is that this is finally ready for a review and possibly even merging (cc @shoyer).
Edit: ~actually, no. Trying to run this using a default environment but with the release candidate reveals a lot of errors. I'm investigating.~ looks like most of this was just numba being incompatible with numpy>=2.0.0.rc1. ~Still investigating the rest.~ more environment issues, but there's a couple of issues which should be fixed in the final release of numpy=2:
FAILED xarray/tests/test_dtypes.py::test_maybe_promote[q-expected19] - AssertionError: assert dtype('O') == <class 'numpy.float64'>
FAILED xarray/tests/test_dtypes.py::test_maybe_promote[Q-expected20] - AssertionError: assert dtype('O') == <class 'numpy.float64'>
However, there's also these:
FAILED xarray/tests/test_conventions.py::TestCFEncodedDataStore::test_roundtrip_mask_and_scale[dtype0-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data] - OverflowError: Failed to decode variable 'x': Python integer -1 out of bounds for uint8
FAILED xarray/tests/test_conventions.py::TestCFEncodedDataStore::test_roundtrip_mask_and_scale[dtype1-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data] - OverflowError: Failed to decode variable 'x': Python integer -1 out of bounds for uint8
~which require further investigation~
Edit: all good, the last one is skipped in the nightly builds because we don't have numpy>=2-compatible builds of netcdf4, yet. Once we do, we'll have to revisit.
if my most recent changes are fine, this should be ready for merging (the remaining upstream-dev test failures will be fixed by #9081).
Once that is done, I will cut a release to have at least one release that is compatible with numpy>=2 before that is released.
Wow. Thanks @keewis :clap: :clap: