
Fix upcasting with python builtin numbers and numpy 2

Open djhoese opened this pull request 1 year ago • 21 comments

See #8402 for more discussion. Bottom line is that numpy 2 changes the rules for casting between two inputs. Due to this and xarray's preference for promoting python scalars to 0d arrays (scalar arrays), xarray objects are being upcast to higher data types where they previously weren't.
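
For reference, a minimal illustrative reproducer of the failure mode described in #8402 (the exact dtypes and values used there may differ; this is only a sketch):

import numpy as np
import xarray as xr

data = xr.DataArray(np.array([0.5, 1.5], dtype="float32"), dims="x")
result = xr.where(data > 1, data, 0.0)
# before this fix, with numpy>=2 the python float 0.0 was first converted to a
# float64 0-d array, so the result was promoted to float64 instead of staying float32
print(result.dtype)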

I'm mainly opening this PR for further and more detailed discussion.

CC @dcherian

  • [ ] Closes #8402
  • [ ] Tests added
  • [ ] User visible changes (including notable bug fixes) are documented in whats-new.rst
  • [ ] New functions/methods are listed in api.rst

djhoese avatar Apr 15 '24 20:04 djhoese

Ugh my local clone was so old it was pointing to master. One sec...

djhoese avatar Apr 15 '24 20:04 djhoese

Ok, so the failing test is the array-api variant (https://github.com/data-apis/array-api-compat), where it expects both the x and y inputs of the where function to have a .dtype. Since we're skipping the scalar->array conversion in this PR, those objects won't have a .dtype. I'm not sure what the rules are in the strict array API for scalar inputs.

djhoese avatar Apr 15 '24 21:04 djhoese

Looks like the array api strictly wants arrays: https://data-apis.org/array-api/latest/API_specification/generated/array_api.where.html

dcherian avatar Apr 15 '24 21:04 dcherian

Related but I don't fully understand it: https://github.com/data-apis/array-api-compat/issues/85

djhoese avatar Apr 15 '24 21:04 djhoese

I guess it depends on how you interpret the array API standard then. I can file an issue if needed. To me, depending on how you read the standard, it means either:

  1. This test is flawed as it tests scalar inputs when the array API specifically defines Array inputs.
  2. The Array API package is flawed because it assumes and requires Array inputs when the standard allows for scalar inputs (I don't think this is true if I'm understanding the description).

The other point is that maybe numpy compatibility is more important until numpy more formally conforms to the array API standard (see the first note on https://data-apis.org/array-api/latest/API_specification/array_object.html#api-specification-array-object--page-root). But also type promotion seems wishy-washy and not super strict: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

I propose, since it works best for me and preserves numpy compatibility, that I update the existing test to cover only the numpy case, and add a new test function with numpy and array API cases that pass array inputs to .where instead of scalars (a rough sketch of what I mean follows below).
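
Something like the following, purely as an illustration (the test names and exact expected dtypes are my assumptions, not the actual test suite):

import numpy as np
import xarray as xr


def test_where_python_scalar_numpy():
    # scalar "other" is only tested against numpy: the python int should not upcast int8 data
    cond = xr.DataArray([True, False], dims="x")
    data = xr.DataArray(np.array([1, 2], dtype="int8"), dims="x")
    assert xr.where(cond, data, 1).dtype == np.dtype("int8")


def test_where_array_other():
    # array-valued "other" can be exercised for both numpy and array API namespaces
    cond = xr.DataArray([True, False], dims="x")
    data = xr.DataArray(np.array([1, 2], dtype="int8"), dims="x")
    other = xr.DataArray(np.array([0, 0], dtype="int8"), dims="x")
    assert xr.where(cond, data, other).dtype == np.dtype("int8")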

djhoese avatar Apr 16 '24 02:04 djhoese

I lean towards (1).

I looked at this for a while, and we'll need major changes around handling array API dtype objects to do this properly.

cc @keewis

dcherian avatar Apr 18 '24 14:04 dcherian

we'll need major changes around handling array API dtype objects to do this properly.

I think the change could be limited to xarray.core.duck_array_ops.as_shared_dtype. According to the Array API section on mixing scalars and arrays, we should use the dtype of the array (though it only looks at scalar + 1 array, so we'd need to extend that).

However, what we currently do is cast all scalars to arrays using asarray, which means python scalars use the OS default dtype (e.g. float64 on most 64-bit systems).

As an algorithm, maybe this could work (a rough sketch follows the list below):

  • separate the input into python scalars and arrays / scalars with dtype
  • determine result_type using just the arrays / scalars with dtype
  • check that all python scalars are compatible with the result (otherwise might have to return object?)
  • cast all input to arrays with the dtype
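
A rough sketch of what that could look like for the numpy-only case (the compatibility check here is just a placeholder, and the function name is made up for illustration):

import numpy as np


def as_shared_dtype_sketch(scalars_or_arrays):
    with_dtype = [x for x in scalars_or_arrays if hasattr(x, "dtype")]
    python_scalars = [x for x in scalars_or_arrays if not hasattr(x, "dtype")]

    if with_dtype:
        out_type = np.result_type(*with_dtype)
    else:
        # only python scalars, so fall back to the default dtypes
        out_type = np.result_type(*python_scalars)

    # placeholder for "check that all python scalars are compatible with the result":
    # e.g. a python float does not fit an integer result dtype
    if np.issubdtype(out_type, np.integer) and any(
        isinstance(x, float) for x in python_scalars
    ):
        out_type = np.result_type(out_type, np.float64)

    return [np.asarray(x, dtype=out_type) for x in scalars_or_arrays]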

keewis avatar Apr 22 '24 09:04 keewis

According to the Array API section on mixing scalars and arrays, we should use the dtype of the array (though it only looks at scalar + 1 array, so we'd need to extend that).

Do you know if this is in line with numpy 2's dtype casting behavior?

djhoese avatar Apr 22 '24 15:04 djhoese

The main numpy namespace is supposed to be Array API compatible, so it should? I don't know for certain, though.

keewis avatar Apr 22 '24 16:04 keewis

check that all python scalars are compatible with the result (otherwise might have to return object?)

How do we check this?

djhoese avatar Apr 22 '24 19:04 djhoese

Here's what I have locally which seems to pass:

Subject: [PATCH] Cast scalars as arrays with result type of only arrays
---
Index: xarray/core/duck_array_ops.py
===================================================================
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
--- a/xarray/core/duck_array_ops.py	(revision e27f572585a6386729a5523c1f9082c72fa8d178)
+++ b/xarray/core/duck_array_ops.py	(date 1713816523554)
@@ -239,20 +239,30 @@
         import cupy as cp
 
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
+        # Pass arrays directly instead of dtypes to result_type so scalars
+        # get handled properly.
+        # Note that result_type() safely gets the dtype from dask arrays without
+        # evaluating them.
+        out_type = dtypes.result_type(*arrays)
     else:
-        arrays = [
-            # https://github.com/pydata/xarray/issues/8402
-            # https://github.com/pydata/xarray/issues/7721
-            x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
-            for x in scalars_or_arrays
-        ]
-    # Pass arrays directly instead of dtypes to result_type so scalars
-    # get handled properly.
-    # Note that result_type() safely gets the dtype from dask arrays without
-    # evaluating them.
-    out_type = dtypes.result_type(*arrays)
+        # arrays = [
+        #     # https://github.com/pydata/xarray/issues/8402
+        #     # https://github.com/pydata/xarray/issues/7721
+        #     x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp)
+        #     for x in scalars_or_arrays
+        # ]
+        objs_with_dtype = [obj for obj in scalars_or_arrays if hasattr(obj, "dtype")]
+        if objs_with_dtype:
+            # Pass arrays directly instead of dtypes to result_type so scalars
+            # get handled properly.
+            # Note that result_type() safely gets the dtype from dask arrays without
+            # evaluating them.
+            out_type = dtypes.result_type(*objs_with_dtype)
+        else:
+            out_type = dtypes.result_type(*scalars_or_arrays)
+        arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
     return [
-        astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays
+        astype(x, out_type, copy=False) for x in arrays
     ]
 
 

I just threw it together to see if it would work. I'm not sure it's entirely correct, but the fact that it's almost exactly like the existing solution, with the only difference being the out_type = changes, makes me feel this is going in a good direction.

Note I had to add the if objs_with_dtype: check because the test passes two python scalars, so there are no arrays from which to determine the result type.

djhoese avatar Apr 22 '24 20:04 djhoese

How do we check this?

Not sure... but there are only so many builtin types that can be involved without requiring object dtype, so we could just enumerate all of them? As far as I can tell, that would be: bool, int, float, str, datetime/date, and timedelta

keewis avatar Apr 22 '24 21:04 keewis

check that all python scalars are compatible with the result (otherwise might have to return object?)

How do we check this?

@keewis Do you have a test that I can add to verify any fix I attempt for this? What do you mean by python scalar being compatible with the result?

djhoese avatar Apr 26 '24 01:04 djhoese

well, for example, what should happen for this:

a = xr.DataArray(np.array([1, 2, 3], dtype="int8"), dims="x")
xr.where(a % 2 == 1, a, 1.2)

according to the algorithm above, we have one array of dtype int8, so that means we'd have to check if 1.2 (a float) is compatible with int8. It is not, so we should promote everything to float (the default would be to use float64, which might be a bit weird).

Something similar:

a = xr.DataArray(np.array(["2019-01-01", "2020-01-01"], dtype="datetime64[ns]"), dim="x")
xr.where(a.x % 2 == 1, a, datetime.datetime(2019, 6, 30))

in that case, the check should succeed, because we can convert a builtin datetime object to datetime64[ns].

keewis avatar Apr 26 '24 19:04 keewis

I committed my (what I consider ugly) implementation of your original approach, @keewis. I'm still not sure I understand how to approach the scalar compatibility check, so if someone has ideas, please make suggestion comments or commit directly if you have the permissions.

djhoese avatar Apr 28 '24 02:04 djhoese

this might be cleaner:

def asarray(data, xp=np, dtype=None):
    return data if is_duck_array(data) else xp.asarray(data, dtype=dtype)


def as_shared_dtype(scalars_or_arrays, xp=np):
    """Cast a arrays to a shared dtype using xarray's type promotion rules."""
    if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
        # as soon as extension arrays are involved we only use this:
        extension_array_types = [
            x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
        ]
        if len(extension_array_types) == len(scalars_or_arrays) and all(
            isinstance(x, type(extension_array_types[0])) for x in extension_array_types
        ):
            return scalars_or_arrays
        raise ValueError(
            f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
        )

    array_type_cupy = array_type("cupy")
    if array_type_cupy and any(
        isinstance(x, array_type_cupy) for x in scalars_or_arrays
    ):
        import cupy as cp

        xp_ = cp
    else:
        xp_ = xp

    # split into python scalars and arrays / numpy scalars (i.e. into weakly and strongly dtyped)
    with_dtype = {}
    python_scalars = {}
    for index, elem in enumerate(scalars_or_arrays):
        append_to = with_dtype if hasattr(elem, "dtype") else python_scalars
        append_to[index] = elem

    if with_dtype:
        to_convert = with_dtype
    else:
        # can't avoid using the default dtypes if we only get weak dtypes
        to_convert = python_scalars
        python_scalars = {}

    arrays = {index: asarray(x, xp=xp_) for index, x in to_convert.items()}

    common_dtype = dtypes.result_type(*arrays.values())
    # TODO(keewis): check that all python scalars are compatible. If not, change the dtype or raise.

    # cast arrays
    cast = {index: astype(x, dtype=common_dtype, copy=False) for index, x in arrays.items()}
    # convert python scalars to arrays with a specific dtype
    converted = {index: asarray(x, xp=xp_, dtype=common_dtype) for index, x in python_scalars.items()}

    # merge both
    combined = cast | converted
    return [x for _, x in sorted(combined.items(), key=lambda x: x[0])]

This is still missing the dtype fallbacks, though.

keewis avatar Apr 28 '24 10:04 keewis

I see now why the dtype fallbacks for scalars are tricky... we basically need to enumerate the casting rules and decide when to return a different dtype (like object). numpy has can_cast with the option to choose the strictness (so we could use "same_kind"), and it accepts python scalar types, while the Array API does not allow that choice, and we also can't pass in python scalar types.
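
For illustration, this is the kind of check can_cast allows with python scalar types (not values) and the "same_kind" strictness mentioned above (not necessarily what the PR ends up using):

import numpy as np

np.can_cast(float, np.dtype("int8"), casting="same_kind")  # False: float -> int changes the kind
np.can_cast(int, np.dtype("int64"), casting="same_kind")   # True
np.can_cast(bool, np.dtype("int8"), casting="same_kind")   # True: bool casts safely to int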

To start, here's the rules from the Array API:

  • complex dtypes are compatible with int, float, or complex
  • float dtypes are compatible with any int or float
  • int dtypes are compatible with int (but beware: python uses BigInt, so the value might exceed the maximum of the dtype)
  • the bool dtype is only compatible with bool

From numpy, we also have these (numpy casting is even more relaxed than this, but that behavior may also cause some confusing issues; a rough sketch encoding these rules follows the list):

  • bool can be cast to int, so it is compatible with anything int is compatible with
  • str dtypes are only compatible with str. Anything else, like formatting and casting to other types, has to be done explicitly before calling as_shared_dtype.
  • datetime dtypes (precisions) are compatible with datetime.datetime, datetime.date, and pd.Timestamp
  • timedelta dtypes (precisions) are compatible with datetime.timedelta and pd.Timedelta. Casting to int is possible, but has to be done explicitly (i.e. we can ignore it here)
  • anything else results in a object dtype
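
Purely as an illustration, these rules could be encoded as a mapping from dtype kind to the compatible python types (this is not the code the PR ends up with):

import datetime

import numpy as np
import pandas as pd

# dtype.kind -> python scalar types considered compatible
COMPATIBLE_SCALARS = {
    "b": (bool,),
    "iu": (bool, int),
    "f": (bool, int, float),
    "c": (bool, int, float, complex),
    "U": (str,),
    "M": (datetime.date, datetime.datetime, pd.Timestamp),
    "m": (datetime.timedelta, pd.Timedelta),
}


def is_compatible(dtype: np.dtype, scalar) -> bool:
    for kinds, scalar_types in COMPATIBLE_SCALARS.items():
        if dtype.kind in kinds:
            return isinstance(scalar, scalar_types)
    # anything else would fall back to the object dtype
    return False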

Edit: it appears NEP 50 describes the changes in detail. I didn't see that before writing both the list above and implementing the changes, so I might have to change both.

keewis avatar Apr 28 '24 11:04 keewis

here's my shot at the ~scalar dtype verification~ (the implementation we settled on in the end is much better). I'm pretty sure it can be cleaned up further (and we need more tests), but it does fix all the casting issues. Edit: note that this depends on the Array API fixes for numpy>=2.

What I don't like is that we're essentially hard-coding the dtype casting hierarchy, but I couldn't figure out a way to make it work without that.

keewis avatar Apr 28 '24 16:04 keewis

FYI to everyone watching this: starting this week I'm switching to a heavier paternity leave than the one I was already on. I think someone else should take over this PR, as I don't think I'll have time to finish it in time for the numpy 2 final release.

djhoese avatar May 12 '24 16:05 djhoese

In an ideal world, I think this would be written something like:

def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)
    dtype = xp.result_type(*scalars_or_arrays)
    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)

The main issues stopping this:

  1. cupy, pandas and old versions of numpy don't support the array API
  2. xp.result_type only supports arrays, not Python scalars

The first issue can be solved with compatibility code. I will raise the second issue on the array API tracker.
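
For the first issue, the compatibility code could be as small as something like this (get_array_namespace_or_numpy is the hypothetical helper from the sketch above, not xarray's actual implementation):

import numpy as np


def get_array_namespace_or_numpy(scalars_or_arrays):
    # use the Array API namespace of the first object that advertises one,
    # otherwise fall back to numpy (cupy / pandas / old numpy would need
    # extra special-casing, as noted above)
    for x in scalars_or_arrays:
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
    return np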

shoyer avatar May 15 '24 16:05 shoyer

While (2) should eventually be resolved by an addition to the Array API, that won't help us resolve the dtype casting before the release of numpy=2.0.

As far as I understand it, to put the above code sample (with slight modifications due to (1)) in as_shared_dtype, we'd have to add compatibility code to dtypes.result_type, at least until we can require a version of the Array API that allows us to forward most of it to xp.result_type.

To do so, we'd have to find a way to split the input of dtypes.result_type into weakly dtyped and explicitly dtyped / dtypes (since dtypes.result_type and xp.result_type accept arrays / explicitly dtyped scalars or dtype objects). Then we can forward the latter to xp.result_type, and figure out what to do with the weakly dtyped data in an additional step.

However, while with numpy we can simply use isinstance(x, np.dtype) to find dtypes, this won't help us with other Array API-implementing libraries as the dtypes are generally opaque objects, and we also don't want to lose the ability to use where on dtype=object arrays. In other words, I can't find a way to separate weakly dtyped data from dtypes and explicitly dtyped data.

If there truly is no way to find dtypes in a general way, we'll have to do the split in as_shared_dtype, where we can guarantee that we don't get dtype objects:

def as_shared_dtype(scalars_or_arrays):
    xp = get_array_namespace_or_numpy(scalars_or_arrays)

    explicitly_dtyped, weakly_dtyped = dtypes.extract_explicitly_dtyped(scalars_or_arrays)
    common_dtype = dtypes.result_type(*explicitly_dtyped)
    dtype = dtypes.adapt_common_dtype(common_dtype, weakly_dtyped)

    return tuple(xp.asarray(x, dtype=dtype) for x in scalars_or_arrays)

Another option would be to pass the weakly dtyped data as a keyword argument to dtypes.result_type, which in the future would allow us to pass both to xp.result_type if we know that a specific library supports python scalars.

Edit: actually, I'd probably go for the latter.

keewis avatar May 22 '24 22:05 keewis

In the most recent commits I've added a way to check if a dtype is compatible with the scalars, using the second option from https://github.com/pydata/xarray/pull/8946#issuecomment-2125901039: split into weakly / strongly dtyped in as_shared_dtype, then pass both to result_type (but as separate arguments).

This appears to resolve the issues we had with dtype casting and numpy>=2, but there are a few other issues that pop up. For example, cupy doesn't have cupy.astype, and the failing pint tests seem to hint at an issue in my algorithm and in the existing non-pint test coverage (not sure, though).
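
For the cupy.astype part, a small compatibility shim along these lines would be one way around it (illustrative only, not necessarily what ends up in duck_array_ops):

def astype(data, dtype, *, xp, **kwargs):
    # prefer the namespace-level function (Array API style) and fall back to
    # the .astype() method for libraries like cupy that only provide the method
    if hasattr(xp, "astype"):
        return xp.astype(data, dtype, **kwargs)
    return data.astype(dtype, **kwargs)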

I'll try to look into the above and into adding additional tests over the weekend.

In the meantime, I'd appreciate reviews for the general idea (cc in particular @shoyer, but also @dcherian).

keewis avatar May 23 '24 21:05 keewis

looks like the most recent commits fixed the remaining failing tests (turns out I forgot to apply our custom dtype casting rules when adjusting the dtype to fit the scalars), so all that's left is to fix mypy and write tests.

keewis avatar May 25 '24 21:05 keewis

Hey @keewis, thanks for continuing to dive into this!

Given that there seems to be a consensus to add support for weak/strong dtypes to the Array API's result_type in the future, let's try to build this code around the assumption that that will work in the future.

I think this could look something like the following:

def _future_array_api_result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp,
) -> np.dtype:
    ...

def result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    xp=None,
) -> np.dtype:
    from xarray.core.duck_array_ops import get_array_namespace

    if xp is None:
        xp = get_array_namespace(arrays_and_dtypes)

    types = {xp.result_type(t) for t in arrays_and_dtypes}
    if all(isinstance(t, np.dtype) for t in types):  # NOTE: slightly more conservative than the existing code
        # only check if there's numpy dtypes – the array API does not
        # define the types we're checking for
        for left, right in PROMOTE_TO_OBJECT:
            if any(np.issubdtype(t, left) for t in types) and any(
                np.issubdtype(t, right) for t in types
            ):
                return np.dtype(object)

    if xp is np:
        return np.result_type(*arrays_and_dtypes)  # fast path

    # TODO: replace with xp.result_type when the array API always supports weak dtypes:
    # https://github.com/data-apis/array-api/issues/805
    return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)

shoyer avatar May 31 '24 17:05 shoyer

that would be ideal, yes. However, I think there are a couple of issues that prevent us from implementing it the way you suggested (I think):

  1. for python scalars, xp.result_type returns the default dtype of that dtype kind, and we can make use of that to apply the custom casting rules only once. However, np.result_type("a") returns dtype('S') (i.e. np.bytes_), and our custom inf, ninf and nan objects will raise. As a consequence, we need custom translation rules for scalars.
  2. as far as I can tell, implementing the fallback version of result_type (_future_array_api_result_type) requires being able to separate weakly dtyped values from strongly dtyped values and dtype objects. However, since the Array API defines dtype objects as opaque objects and doesn't give any guarantees about it, I can't seem to figure out a way to actually do that.
  3. we kinda support xr.where(False, 0, 1) right now (even though that feels kinda weird to have), which is going to be undefined behavior if I read the Array API discussion correctly. If we want to continue supporting that, we have to be able to use the default dtypes in that case. As far as I can tell, that also requires 2.

Thinking a bit more about 2, maybe instead of if xp is np we can use if any(isinstance(t, np.dtype) for t in arrays_and_dtypes). That way, libraries that make use of numpy dtypes (datetime, timedelta, string, object) like cupy or sparse can use their version of result_type which should already support these.

keewis avatar Jun 05 '24 12:06 keewis

All the considerations above can be satisfied by

def _future_array_api_result_type(*arrays_and_dtypes, weakly_dtyped, xp):
    dtype = xp.result_type(*arrays_and_dtypes, *weakly_dtyped)
    if not weakly_dtyped or is_object(dtype):
        return dtype

    possible_dtypes = {
        complex: "complex64",
        float: "float32",
        int: "int8",
        bool: "bool",
        str: "str",
    }
    dtypes = [possible_dtypes.get(type(x), "object") for x in weakly_dtyped]
    common_dtype = xp.result_type(dtype, *dtypes)
    return common_dtype

def determine_types(t, xp):
    if isinstance(t, str):
        return np.dtype("U")
    elif isinstance(t, (AlwaysGreaterThan, AlwaysLessThan, utils.ReprObject)):
        return object
    else:
        return xp.result_type(t)

def result_type(
    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
    weakly_dtyped=None,
    xp=None,
) -> np.dtype:
    from xarray.core.duck_array_ops import asarray, get_array_namespace

    if xp is None:
        xp = get_array_namespace(arrays_and_dtypes)

    if weakly_dtyped is None:
        weakly_dtyped = []

    if not arrays_and_dtypes:
        # no explicit dtypes, so we simply convert to 0-d arrays using default dtypes
        arrays_and_dtypes = [asarray(x, xp=xp) for x in weakly_dtyped]  # type: ignore
        weakly_dtyped = []

    types = {determine_types(t, xp=xp) for t in [*arrays_and_dtypes, *weakly_dtyped]}
    if any(isinstance(t, np.dtype) for t in types):
        # only check if there's numpy dtypes – the array API does not
        # define the types we're checking for
        for left, right in PROMOTE_TO_OBJECT:
            if any(np.issubdtype(t, left) for t in types) and any(
                np.issubdtype(t, right) for t in types
            ):
                return np.dtype(object)

    if xp is np or any(isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes):
        return xp.result_type(*arrays_and_dtypes, *weakly_dtyped)

    return _future_array_api_result_type(*arrays_and_dtypes, weakly_dtyped=weakly_dtyped, xp=xp)

That still requires passing the weakly dtyped values separately into result_type, which as I mentioned is because I couldn't find a way to implement 2 and 3 without this.

Edit: Additionally, we have to explicitly cast the other parameter of duck_array_ops.fillna to an array, or modify some of the tests where we pass list / range objects.
Edit 2: we might have to remove the custom inf / ninf / na objects, though.
Edit 3: another issue: np.result_type will try to consider strings / bytes as dtype names, not as scalars.
Edit 4: there are a lot of other edge cases. We'll have to try and sort this out at a later time.

keewis avatar Jun 05 '24 12:06 keewis

as a summary of the discussion we just had, and so I don't forget in the time until I implement this (cc @shoyer):

  • splitting out weakly-dtyped values from dtypes / strongly-dtyped values should happen in our fallback version of the Array API's result_type (the aim being to be a drop-in replacement for a version of the Array API that supports python scalars in result_type)
  • we consider only the numeric types (i.e. bool, int, float and complex) and str / bytes, allowing us to separate dtype objects from scalars
  • dtype name strings will not be passed around from within xarray's code, and so will never reach our version of result_type (though whether that's something that already happens is something to be verified)
  • datetime / timedelta objects are currently not supported, but may be in the future (after finally releasing the numpy 2-compatible version of xarray)

keewis avatar Jun 07 '24 19:06 keewis

I decided to do this now rather than later. Good news is that this is finally ready for a review and possibly even merging (cc @shoyer).

Edit: ~actually, no. Trying to run this using a default environment but with the release candidate reveals a lot of errors. I'm investigating.~ looks like most of this was just numba being incompatible with numpy>=2.0.0.rc1. ~Still investigating the rest.~ more environment issues, but there are a couple of issues that should be fixed in the final release of numpy=2:

FAILED xarray/tests/test_dtypes.py::test_maybe_promote[q-expected19] - AssertionError: assert dtype('O') == <class 'numpy.float64'>
FAILED xarray/tests/test_dtypes.py::test_maybe_promote[Q-expected20] - AssertionError: assert dtype('O') == <class 'numpy.float64'>

However, there's also these:

FAILED xarray/tests/test_conventions.py::TestCFEncodedDataStore::test_roundtrip_mask_and_scale[dtype0-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data] - OverflowError: Failed to decode variable 'x': Python integer -1 out of bounds for uint8
FAILED xarray/tests/test_conventions.py::TestCFEncodedDataStore::test_roundtrip_mask_and_scale[dtype1-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data] - OverflowError: Failed to decode variable 'x': Python integer -1 out of bounds for uint8

~which require further investigation~

Edit: all good, the last one is skipped in the nightly builds because we don't have numpy>=2-compatible builds of netcdf4, yet. Once we do, we'll have to revisit.

keewis avatar Jun 07 '24 20:06 keewis

if my most recent changes are fine, this should be ready for merging (the remaining upstream-dev test failures will be fixed by #9081).

Once that is done, I will cut a release so that there is at least one xarray release compatible with numpy>=2 before numpy 2 itself is released.

keewis avatar Jun 10 '24 10:06 keewis


Wow. Thanks @keewis :clap: :clap:

dcherian avatar Jun 10 '24 17:06 dcherian