where/join/concat failures using esoteric dtypes
What happened?
Based on this comment, I ran a mini test-suite checking current dtype support; it produced a few failures for extension-array (EA) and numpy.StringDType edge cases that xarray doesn't cover yet.
The tests are AI-generated, but they are meaningful and relevant.
Some of the test cases might be already covered in: https://github.com/pydata/xarray/pull/10423
It's the generalization of: https://github.com/pydata/xarray/issues/10964 (but that's pretty specific, self-contained)
cc @dcherian @ilan-gold
What did you expect to happen?
Clean pass
Minimal Complete Verifiable Example
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "xarray[complete]@git+https://github.com/pydata/xarray.git@main",
# "pyarrow",
# ]
# ///
#
# This script automatically imports the development branch of xarray to check for issues.
# Please delete this header if you have _not_ tested this script with `uv run`!
import xarray as xr

# Dump the full environment report up front so failures are reproducible.
xr.show_versions()
# your reproducer code ...
import numpy as np
import pandas as pd
import xarray as xr
import pyarrow as pa # noqa: F401
from numpy.dtypes import StringDType
def print_header():
    """Print the versions of every library involved, then a blank line."""
    for label, module in (
        ("xarray", xr),
        ("pandas", pd),
        ("numpy", np),
        ("pyarrow", pa),
    ):
        print(f"{label} version: {module.__version__}")
    print()
def run_test(name, func):
    """Run *func* and print one PASS/FAIL line labelled *name*.

    Exceptions raised by *func* are caught and reported, never propagated,
    so a single failing case cannot abort the rest of the suite.
    """
    outcome = f"PASS | {name}"
    try:
        func()
    except Exception as err:
        outcome = f"FAIL | {name} ({type(err).__name__}: {err})"
    print(outcome)
# =============================================================================
# NumPy StringDType tests
# =============================================================================
def test_numpy_stringdtype_values_where():
    """where() on data held in numpy's variable-width StringDType."""
    values = np.array(["a", "b", "c"], dtype=StringDType())
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="str_val")
    _ = arr.where(arr != "b")

def test_numpy_stringdtype_values_concat():
    """concat along a new dimension with StringDType-backed values."""
    values = np.array(["a", "b", "c"], dtype=StringDType())
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="str_val")
    _ = xr.concat([arr, arr], dim="rep")

def test_numpy_stringdtype_coord_align():
    """Outer align of a StringDType coordinate against a plain string list."""
    labels = np.array(["A", "B", "C"], dtype=StringDType())
    left = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v1")
    right = xr.DataArray([10, 20], dims="label", coords={"label": ["B", "C"]}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_numpy_stringdtype_values_where_null():
    """where() with an isnull()-derived mask on StringDType data holding NA."""
    values = np.array(["a", pd.NA, "c"], dtype=StringDType(na_object=pd.NA))
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="str_val_na")
    _ = arr.where(~arr.isnull())
# =============================================================================
# string[pyarrow] tests
# =============================================================================
def test_string_pyarrow_values_where():
    """where() on string[pyarrow]-backed values."""
    arr = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray()
    _ = arr.where(arr != "foo")

def test_string_pyarrow_values_concat():
    """concat along a new dimension with string[pyarrow] values."""
    arr = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray()
    _ = xr.concat([arr, arr], dim="rep")

def test_string_pyarrow_values_align():
    """Outer align of a string[pyarrow] array against a slice of itself."""
    arr = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray()
    _ = xr.align(arr, arr.isel(index=[0, 1]), join="outer")

def test_string_pyarrow_values_where_null():
    """where() with an isnull()-derived mask on string[pyarrow] values."""
    arr = pd.Series(["foo", None, "bar"], dtype="string[pyarrow]", name="s").to_xarray()
    _ = arr.where(~arr.isnull())

def test_string_pyarrow_coord_where():
    """where() on numeric values carried by a string[pyarrow] coordinate."""
    labels = pd.Index(["A", "B", "C"], dtype="string[pyarrow]", name="label")
    arr = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v")
    _ = arr.where(arr > 1)

def test_string_pyarrow_coord_concat():
    """concat with a string[pyarrow] coordinate."""
    labels = pd.Index(["A", "B", "C"], dtype="string[pyarrow]", name="label")
    arr = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v")
    _ = xr.concat([arr, arr], dim="rep")

def test_string_pyarrow_coord_align():
    """Outer align of a string[pyarrow] coordinate vs. a plain string list."""
    labels = pd.Index(["A", "B", "C"], dtype="string[pyarrow]", name="label")
    left = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v1")
    right = xr.DataArray([10, 20], dims="label", coords={"label": ["B", "D"]}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_string_pyarrow_coord_align_with_null():
    """Outer align where one string[pyarrow] coordinate contains a null."""
    left_labels = pd.Index(["A", None, "C"], dtype="string[pyarrow]", name="label")
    left = xr.DataArray([1, 2, 3], dims="label", coords={"label": left_labels}, name="v1")
    right_labels = pd.Index(["A", "B"], dtype="string[pyarrow]", name="label")
    right = xr.DataArray([10, 20], dims="label", coords={"label": right_labels}, name="v2")
    _ = xr.align(left, right, join="outer")
# =============================================================================
# date32[pyarrow] as coordinate
# =============================================================================
def _date32_index():
    """Three consecutive days as a date32[pyarrow] pd.Index named 'time'."""
    daily = pd.date_range("2024-01-01", periods=3, freq="D")
    return pd.Index(pd.Series(daily, name="date").astype("date32[pyarrow]"), name="time")

def test_date32_coord_where():
    """where() with a comparison against a date32[pyarrow] coordinate."""
    idx = _date32_index()
    arr = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": idx}, name="val")
    _ = arr.where(arr["time"] >= idx[1])

def test_date32_coord_concat():
    """concat with a date32[pyarrow] coordinate."""
    arr = xr.DataArray(
        [10.0, 20.0, 30.0], dims="time", coords={"time": _date32_index()}, name="val"
    )
    _ = xr.concat([arr, arr], dim="rep")

def test_date32_coord_align_vs_datetime64():
    """Outer align of a date32[pyarrow] coordinate against datetime64."""
    left = xr.DataArray(
        [10.0, 20.0, 30.0], dims="time", coords={"time": _date32_index()}, name="v1"
    )
    right = xr.DataArray(
        [1.0, 2.0, 3.0],
        dims="time",
        coords={"time": pd.date_range("2024-01-01", periods=3, freq="D")},
        name="v2",
    )
    _ = xr.align(left, right, join="outer")

def test_date32_coord_where_null():
    """where() masked by isnull() on a date32[pyarrow] coordinate with NaT."""
    stamps = [pd.Timestamp("2024-01-01"), pd.NaT, pd.Timestamp("2024-01-03")]
    idx = pd.Index(pd.Series(stamps, name="date").astype("date32[pyarrow]"), name="time")
    arr = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": idx}, name="val")
    _ = arr.where(~arr["time"].isnull())
# =============================================================================
# Nullable Int64 and int64[pyarrow]
# =============================================================================
def test_int64_nullable_values_where():
    """where() on pandas nullable Int64 values."""
    arr = pd.Series([1, 2, None], dtype="Int64", name="v").to_xarray()
    _ = arr.where(arr > 1)

def test_int64_nullable_values_concat():
    """concat along a new dimension with nullable Int64 values."""
    arr = pd.Series([1, 2, None], dtype="Int64", name="v").to_xarray()
    _ = xr.concat([arr, arr], dim="rep")

def test_int64_nullable_coord_align():
    """Outer align of an Int64 coordinate against a plain int list."""
    idx = pd.Index([1, 2, 3], dtype="Int64", name="i")
    left = xr.DataArray([10, 20, 30], dims="i", coords={"i": idx}, name="v1")
    right = xr.DataArray([100, 200], dims="i", coords={"i": [2, 4]}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_int64_nullable_values_where_null():
    """where() with a notnull() mask on Int64 values containing NA."""
    arr = pd.Series([1, None, 3], dtype="Int64", name="v").to_xarray()
    _ = arr.where(arr.notnull())

def test_int64_pyarrow_values_where():
    """where() on int64[pyarrow] values."""
    arr = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v").to_xarray()
    _ = arr.where(arr > 1)

def test_int64_pyarrow_values_concat():
    """concat along a new dimension with int64[pyarrow] values."""
    arr = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v").to_xarray()
    _ = xr.concat([arr, arr], dim="rep")

def test_int64_pyarrow_coord_align():
    """Outer align of an int64[pyarrow] coordinate against a plain int list."""
    idx = pd.Index([1, 2, 3], dtype="int64[pyarrow]", name="i_arrow")
    left = xr.DataArray([5, 6, 7], dims="i_arrow", coords={"i_arrow": idx}, name="v1")
    right = xr.DataArray([10, 20], dims="i_arrow", coords={"i_arrow": [2, 4]}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_int64_pyarrow_values_where_null():
    """where() with a notnull() mask on int64[pyarrow] values containing NA."""
    arr = pd.Series([1, None, 3], dtype="int64[pyarrow]", name="v").to_xarray()
    _ = arr.where(arr.notnull())
# =============================================================================
# Categorical as values and coordinate
# =============================================================================
def test_categorical_values_where():
    """where() on Categorical-backed values."""
    cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
    arr = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val")
    _ = arr.where(arr != "a")

def test_categorical_values_concat():
    """concat along a new dimension with Categorical values."""
    cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
    arr = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val")
    _ = xr.concat([arr, arr], dim="rep")

def test_categorical_values_align_vs_object():
    """Outer align of Categorical values against an object-dtype array."""
    cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
    left = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val")
    right = xr.DataArray(
        np.array(["a", "c", "d"], dtype=object),
        dims="x",
        coords={"x": [0, 1, 2]},
        name="obj_val",
    )
    _ = xr.align(left, right, join="outer")

def test_categorical_values_where_null():
    """where() with an isnull()-derived mask on Categorical values with NA."""
    cat = pd.Categorical(["a", None, "b", "c"], categories=["a", "b", "c"])
    arr = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val_na")
    _ = arr.where(~arr.isnull())

def test_categorical_coord_where():
    """where() on numeric values indexed by a CategoricalIndex coordinate."""
    labels = pd.CategoricalIndex(["A", "B", "C"], categories=["A", "B", "C"], name="lab")
    arr = xr.DataArray([1, 2, 3], dims="lab", coords={"lab": labels}, name="v")
    _ = arr.where(arr > 1)

def test_categorical_coord_concat():
    """concat with a CategoricalIndex coordinate."""
    labels = pd.CategoricalIndex(["A", "B", "C"], categories=["A", "B", "C"], name="lab")
    arr = xr.DataArray([1, 2, 3], dims="lab", coords={"lab": labels}, name="v")
    _ = xr.concat([arr, arr], dim="rep")

def test_categorical_coord_align_vs_object_index():
    """Inner align of an object-dtype index vs. a CategoricalIndex."""
    ds_obj = xr.Dataset(
        {"v": ("lab", [1, 2])},
        coords={"lab": pd.Index(["A", "B"], dtype="object", name="lab")},
    )
    ds_cat = xr.Dataset(
        {"v": ("lab", [3, 4])},
        coords={"lab": pd.CategoricalIndex(["B", "C"], categories=["A", "B", "C"], name="lab")},
    )
    _ = xr.align(ds_obj, ds_cat, join="inner")

def test_categorical_coord_align_with_null():
    """Outer align where one CategoricalIndex coordinate contains a null."""
    left_idx = pd.CategoricalIndex(["A", None, "C"], categories=["A", "B", "C"], name="lab")
    left = xr.DataArray([1, 2, 3], dims="lab", coords={"lab": left_idx}, name="v1")
    right_idx = pd.CategoricalIndex(["A", "B"], categories=["A", "B", "C"], name="lab")
    right = xr.DataArray([10, 20], dims="lab", coords={"lab": right_idx}, name="v2")
    _ = xr.align(left, right, join="outer")
# =============================================================================
# Main
# =============================================================================
def main():
    """Print the environment header, then run every dtype edge-case test."""
    print_header()
    # (label, callable) pairs, grouped by dtype family.
    tests = [
        # NumPy StringDType
        ("NumPy StringDType values: where", test_numpy_stringdtype_values_where),
        ("NumPy StringDType values: concat", test_numpy_stringdtype_values_concat),
        ("NumPy StringDType coord vs string: align", test_numpy_stringdtype_coord_align),
        ("NumPy StringDType values: where with null", test_numpy_stringdtype_values_where_null),
        # string[pyarrow]
        ("string[pyarrow] values: where", test_string_pyarrow_values_where),
        ("string[pyarrow] values: concat", test_string_pyarrow_values_concat),
        ("string[pyarrow] values: align self vs slice", test_string_pyarrow_values_align),
        ("string[pyarrow] values: where with null", test_string_pyarrow_values_where_null),
        ("string[pyarrow] coord: where", test_string_pyarrow_coord_where),
        ("string[pyarrow] coord: concat", test_string_pyarrow_coord_concat),
        ("string[pyarrow] coord vs object: align", test_string_pyarrow_coord_align),
        ("string[pyarrow] coord vs object with null: align", test_string_pyarrow_coord_align_with_null),
        # date32[pyarrow]
        ("date32[pyarrow] coord: where", test_date32_coord_where),
        ("date32[pyarrow] coord: concat", test_date32_coord_concat),
        ("date32[pyarrow] coord vs datetime64: align", test_date32_coord_align_vs_datetime64),
        ("date32[pyarrow] coord: where with null", test_date32_coord_where_null),
        # Int64 / int64[pyarrow]
        ("Int64 values: where", test_int64_nullable_values_where),
        ("Int64 values: concat", test_int64_nullable_values_concat),
        ("Int64 coord vs int64: align", test_int64_nullable_coord_align),
        ("Int64 values: where with null", test_int64_nullable_values_where_null),
        ("int64[pyarrow] values: where", test_int64_pyarrow_values_where),
        ("int64[pyarrow] values: concat", test_int64_pyarrow_values_concat),
        ("int64[pyarrow] coord vs int64: align", test_int64_pyarrow_coord_align),
        ("int64[pyarrow] values: where with null", test_int64_pyarrow_values_where_null),
        # Categorical
        ("Categorical values: where", test_categorical_values_where),
        ("Categorical values: concat", test_categorical_values_concat),
        ("Categorical values vs object: align", test_categorical_values_align_vs_object),
        ("Categorical values: where with null", test_categorical_values_where_null),
        ("CategoricalIndex coord: where", test_categorical_coord_where),
        ("CategoricalIndex coord: concat", test_categorical_coord_concat),
        ("object vs CategoricalIndex coord: align", test_categorical_coord_align_vs_object_index),
        ("CategoricalIndex coord with null: align", test_categorical_coord_align_with_null),
    ]
    for label, test_func in tests:
        run_test(label, test_func)

if __name__ == "__main__":
    main()
Steps to reproduce
No response
MVCE confirmation
- [x] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- [x] Complete example — the example is self-contained, including all data and the text of any traceback.
- [x] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- [x] New issue — a search of GitHub Issues suggests this is not a duplicate.
- [x] Recent environment — the issue occurs with the latest version of xarray and its dependencies.
Relevant log output
xarray version: 2025.11.0
pandas version: 2.3.3
numpy version: 2.3.5
pyarrow version: 22.0.0
FAIL | NumPy StringDType values: where (DTypePromotionError: The DType <class 'numpy.dtypes.StringDType'> could not be promoted by <class 'numpy.dtypes._PyFloatDType'>. This means that no common DType exists for the given inputs. For example they cannot be stored in a single array unless the dtype is `object`. The full list of DTypes is: (<class 'numpy.dtypes.StringDType'>, <class 'numpy.dtypes._PyFloatDType'>))
PASS | NumPy StringDType values: concat
PASS | NumPy StringDType coord vs string: align
FAIL | NumPy StringDType values: where with null (TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtypes.StringDType'>.)
FAIL | string[pyarrow] values: where (TypeError: boolean value of NA is ambiguous)
PASS | string[pyarrow] values: concat
PASS | string[pyarrow] values: align self vs slice
PASS | string[pyarrow] values: where with null
PASS | string[pyarrow] coord: where
PASS | string[pyarrow] coord: concat
PASS | string[pyarrow] coord vs object: align
PASS | string[pyarrow] coord vs object with null: align
PASS | date32[pyarrow] coord: where
PASS | date32[pyarrow] coord: concat
FAIL | date32[pyarrow] coord vs datetime64: align (TypeError: Cannot interpret 'date32[day][pyarrow]' as a data type)
PASS | date32[pyarrow] coord: where with null
FAIL | Int64 values: where (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
PASS | Int64 values: concat
FAIL | Int64 coord vs int64: align (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
FAIL | Int64 values: where with null (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
FAIL | int64[pyarrow] values: where (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
FAIL | int64[pyarrow] values: concat (IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices)
FAIL | int64[pyarrow] coord vs int64: align (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
FAIL | int64[pyarrow] values: where with null (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
FAIL | Categorical values: where (TypeError: Cannot interpret 'CategoricalDtype(categories=['a', 'b', 'c'], ordered=False, categories_dtype=object)' as a data type)
PASS | Categorical values: concat
PASS | Categorical values vs object: align
FAIL | Categorical values: where with null (TypeError: Cannot interpret 'CategoricalDtype(categories=['a', 'b', 'c'], ordered=False, categories_dtype=object)' as a data type)
PASS | CategoricalIndex coord: where
PASS | CategoricalIndex coord: concat
FAIL | object vs CategoricalIndex coord: align (TypeError: Cannot interpret 'CategoricalDtype(categories=['A', 'B', 'C'], ordered=False, categories_dtype=object)' as a data type)
FAIL | CategoricalIndex coord with null: align (TypeError: Cannot interpret 'CategoricalDtype(categories=['A', 'B', 'C'], ordered=False, categories_dtype=object)' as a data type)
Anything else we need to know?
No response
Environment
An extended version that additionally covers groupby, sort, sel, and `.values` access:
This is a longer script
#!/usr/bin/env python
import numpy as np
import pandas as pd
import xarray as xr
import pyarrow as pa # noqa: F401
from numpy.dtypes import StringDType
import xarray.testing as xrt
def print_header():
    """Print the versions of every library involved, then a blank line."""
    for label, module in (
        ("xarray", xr),
        ("pandas", pd),
        ("numpy", np),
        ("pyarrow", pa),
    ):
        print(f"{label} version: {module.__version__}")
    print()

def run_test(name, func):
    """Run *func* and print one PASS/FAIL line labelled *name*.

    Exceptions raised by *func* are caught and reported, never propagated,
    so a single failing case cannot abort the rest of the suite.
    """
    outcome = f"PASS | {name}"
    try:
        func()
    except Exception as err:
        outcome = f"FAIL | {name} ({type(err).__name__}: {err})"
    print(outcome)
# =============================================================================
# NumPy StringDType tests
# =============================================================================
def test_numpy_stringdtype_values_where():
    """where() on data held in numpy's variable-width StringDType."""
    values = np.array(["a", "b", "c"], dtype=StringDType())
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="str_val")
    _ = arr.where(arr != "b")

def test_numpy_stringdtype_values_concat():
    """concat along a new dimension with StringDType-backed values."""
    values = np.array(["a", "b", "c"], dtype=StringDType())
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="str_val")
    _ = xr.concat([arr, arr], dim="rep")

def test_numpy_stringdtype_coord_align():
    """Outer align of a StringDType coordinate against a plain string list."""
    labels = np.array(["A", "B", "C"], dtype=StringDType())
    left = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v1")
    right = xr.DataArray([10, 20], dims="label", coords={"label": ["B", "C"]}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_numpy_stringdtype_values_where_null():
    """where() with an isnull()-derived mask on StringDType data holding NA."""
    values = np.array(["a", pd.NA, "c"], dtype=StringDType(na_object=pd.NA))
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="str_val_na")
    _ = arr.where(~arr.isnull())

def test_numpy_stringdtype_values_values():
    """.values access on StringDType-backed data."""
    values = np.array(["a", "b", "c"], dtype=StringDType())
    _ = xr.DataArray(values, dims="x", name="str_val").values

def _check_sortby_self_values(da):
    """Sort *da* by its own values and check the ordering invariants.

    The non-null prefix must be sorted (Python default order; examples are
    homogeneous) and any nulls must all trail the non-null values.
    """
    ordered = list(da.sortby(da).values)
    null_flags = [bool(pd.isna(v)) for v in ordered]
    non_null = [v for v, missing in zip(ordered, null_flags) if not missing]
    assert non_null == sorted(non_null), f"Non-null values not sorted: {non_null}"
    # Once a null appears, everything after it must be null too.
    if any(null_flags):
        first_null_idx = null_flags.index(True)
        assert all(null_flags[first_null_idx:]), "Nulls are not all trailing in sorted result"

def test_numpy_stringdtype_values_sortby_self():
    """sortby(self) on StringDType-backed data."""
    values = np.array(["b", "a", "c"], dtype=StringDType())
    _check_sortby_self_values(xr.DataArray(values, dims="x", name="str_val"))
# =============================================================================
# string[pyarrow] tests
# =============================================================================
def test_string_pyarrow_values_where():
    """where() on string[pyarrow]-backed values."""
    arr = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray()
    _ = arr.where(arr != "foo")

def test_string_pyarrow_values_concat():
    """concat along a new dimension with string[pyarrow] values."""
    arr = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray()
    _ = xr.concat([arr, arr], dim="rep")

def test_string_pyarrow_values_align():
    """Outer align of a string[pyarrow] array against a slice of itself."""
    arr = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray()
    _ = xr.align(arr, arr.isel(index=[0, 1]), join="outer")

def test_string_pyarrow_values_where_null():
    """where() with an isnull()-derived mask on string[pyarrow] values."""
    arr = pd.Series(["foo", None, "bar"], dtype="string[pyarrow]", name="s").to_xarray()
    _ = arr.where(~arr.isnull())

def test_string_pyarrow_values_values():
    """.values access on string[pyarrow]-backed data."""
    _ = pd.Series(["foo", "bar", None], dtype="string[pyarrow]", name="s").to_xarray().values

def test_string_pyarrow_values_sortby_self():
    """sortby(self) on string[pyarrow]-backed data with a null."""
    arr = pd.Series(["b", "a", None], dtype="string[pyarrow]", name="s").to_xarray()
    _check_sortby_self_values(arr)

def test_string_pyarrow_coord_where():
    """where() on numeric values carried by a string[pyarrow] coordinate."""
    labels = pd.Index(["A", "B", "C"], dtype="string[pyarrow]", name="label")
    arr = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v")
    _ = arr.where(arr > 1)

def test_string_pyarrow_coord_concat():
    """concat with a string[pyarrow] coordinate."""
    labels = pd.Index(["A", "B", "C"], dtype="string[pyarrow]", name="label")
    arr = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v")
    _ = xr.concat([arr, arr], dim="rep")

def test_string_pyarrow_coord_align():
    """Outer align of a string[pyarrow] coordinate vs. a plain string list."""
    labels = pd.Index(["A", "B", "C"], dtype="string[pyarrow]", name="label")
    left = xr.DataArray([1, 2, 3], dims="label", coords={"label": labels}, name="v1")
    right = xr.DataArray([10, 20], dims="label", coords={"label": ["B", "D"]}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_string_pyarrow_coord_align_with_null():
    """Outer align where one string[pyarrow] coordinate contains a null."""
    left_labels = pd.Index(["A", None, "C"], dtype="string[pyarrow]", name="label")
    left = xr.DataArray([1, 2, 3], dims="label", coords={"label": left_labels}, name="v1")
    right_labels = pd.Index(["A", "B"], dtype="string[pyarrow]", name="label")
    right = xr.DataArray([10, 20], dims="label", coords={"label": right_labels}, name="v2")
    _ = xr.align(left, right, join="outer")

def test_string_pyarrow_coord_index_ops():
    """sel / loc / sortby through a string[pyarrow]-backed index."""
    labels = pd.Index(["C", "A", "B"], dtype="string[pyarrow]", name="label")
    arr = xr.DataArray([3, 1, 2], dims="label", coords={"label": labels}, name="v")
    _ = arr.sel(label="B")
    _ = arr.sel(label=["A", "C"])
    _ = arr.loc[dict(label="A")]
    _ = arr.sortby("label")

def test_string_pyarrow_groupby_dropna():
    """groupby and dropna over a string[pyarrow] coordinate."""
    labels = pd.Index(["A", "B", "A", "C"], dtype="string[pyarrow]", name="label")
    arr = xr.DataArray([1.0, 2.0, 3.0, np.nan], dims="label", coords={"label": labels}, name="v")
    _ = arr.groupby("label").sum()
    _ = arr.dropna("label")
# =============================================================================
# date32[pyarrow] as coordinate
# =============================================================================
def _date32_time_index():
    """Three consecutive days as a date32[pyarrow] pd.Index named 'time'."""
    daily = pd.date_range("2024-01-01", periods=3, freq="D")
    return pd.Index(pd.Series(daily, name="date").astype("date32[pyarrow]"), name="time")

def test_date32_coord_where():
    """where() with a comparison against a date32[pyarrow] coordinate."""
    idx = _date32_time_index()
    arr = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": idx}, name="val")
    _ = arr.where(arr["time"] >= idx[1])

def test_date32_coord_concat():
    """concat with a date32[pyarrow] coordinate."""
    arr = xr.DataArray(
        [10.0, 20.0, 30.0], dims="time", coords={"time": _date32_time_index()}, name="val"
    )
    _ = xr.concat([arr, arr], dim="rep")

def test_date32_coord_align_vs_datetime64():
    """Outer align of a date32[pyarrow] coordinate against datetime64."""
    left = xr.DataArray(
        [10.0, 20.0, 30.0], dims="time", coords={"time": _date32_time_index()}, name="v1"
    )
    right = xr.DataArray(
        [1.0, 2.0, 3.0],
        dims="time",
        coords={"time": pd.date_range("2024-01-01", periods=3, freq="D")},
        name="v2",
    )
    _ = xr.align(left, right, join="outer")

def test_date32_coord_where_null():
    """where() masked by isnull() on a date32[pyarrow] coordinate with NaT."""
    stamps = [pd.Timestamp("2024-01-01"), pd.NaT, pd.Timestamp("2024-01-03")]
    idx = pd.Index(pd.Series(stamps, name="date").astype("date32[pyarrow]"), name="time")
    arr = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": idx}, name="val")
    _ = arr.where(~arr["time"].isnull())

def test_date32_coord_values():
    """.values access on a date32[pyarrow] coordinate."""
    arr = xr.DataArray(
        [10.0, 20.0, 30.0], dims="time", coords={"time": _date32_time_index()}, name="val"
    )
    _ = arr["time"].values

def test_date32_coord_index_ops():
    """sel and sortby through a date32[pyarrow]-backed index."""
    idx = _date32_time_index()
    arr = xr.DataArray([10.0, 20.0, 30.0], dims="time", coords={"time": idx}, name="v")
    _ = arr.sel(time=idx[1])
    _ = arr.sortby("time")

def test_date32_groupby_dropna():
    """groupby and dropna over a date32[pyarrow] coordinate containing NaT."""
    stamps = [pd.Timestamp("2024-01-01"), pd.Timestamp("2024-01-01"), pd.NaT]
    idx = pd.Index(pd.Series(stamps, name="date").astype("date32[pyarrow]"), name="time")
    arr = xr.DataArray([1.0, 2.0, 3.0], dims="time", coords={"time": idx}, name="v")
    _ = arr.groupby("time").sum()
    _ = arr.dropna("time")
# =============================================================================
# Nullable Int64 and int64[pyarrow]
# =============================================================================
def test_int64_nullable_values_where():
    """where() on pandas nullable Int64 values."""
    s = pd.Series([1, 2, None], dtype="Int64", name="v")
    da = s.to_xarray()
    _ = da.where(da > 1)

def test_int64_nullable_values_concat():
    """concat along a new dimension with nullable Int64 values."""
    s = pd.Series([1, 2, None], dtype="Int64", name="v")
    da = s.to_xarray()
    _ = xr.concat([da, da], dim="rep")

def test_int64_nullable_coord_align():
    """Outer align of an Int64 coordinate against a plain int list."""
    idx = pd.Index([1, 2, 3], dtype="Int64", name="i")
    da1 = xr.DataArray([10, 20, 30], dims="i", coords={"i": idx}, name="v1")
    da2 = xr.DataArray([100, 200], dims="i", coords={"i": [2, 4]}, name="v2")
    _ = xr.align(da1, da2, join="outer")

def test_int64_nullable_values_where_null():
    """where() with a notnull() mask on Int64 values containing NA."""
    s = pd.Series([1, None, 3], dtype="Int64", name="v")
    da = s.to_xarray()
    mask = da.notnull()
    _ = da.where(mask)

def test_int64_nullable_values_values():
    """.values access on nullable Int64 data."""
    s = pd.Series([1, 2, None], dtype="Int64", name="v")
    da = s.to_xarray()
    _ = da.values

def test_int64_nullable_arith_int64():
    """Int64 + int64: both operand orders should give float64 with NaN for NA.

    Note: both arrays already share the "index" dimension, so no explicit
    broadcasting step is needed (the previous no-op rename was removed).
    """
    s = pd.Series([1, 2, None], dtype="Int64", name="v")
    da_nullable = s.to_xarray()
    da_int = xr.DataArray(
        np.array([10, 20, 30], dtype="int64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name="base",
    )
    out1 = da_nullable + da_int
    out2 = da_int + da_nullable
    expected = xr.DataArray(
        np.array([11.0, 22.0, np.nan], dtype="float64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name=None,
    )
    xrt.assert_allclose(out1, expected)
    xrt.assert_allclose(out2, expected)

def test_int64_nullable_arith_float64():
    """Int64 + float64: both operand orders should give float64 with NaN for NA."""
    s = pd.Series([1, 2, None], dtype="Int64", name="v")
    da_nullable = s.to_xarray()
    da_float = xr.DataArray(
        np.array([0.5, 1.5, 2.5], dtype="float64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name="f",
    )
    out1 = da_nullable + da_float
    out2 = da_float + da_nullable
    expected = xr.DataArray(
        np.array([1.5, 3.5, np.nan], dtype="float64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name=None,
    )
    xrt.assert_allclose(out1, expected)
    xrt.assert_allclose(out2, expected)

def test_Int64_coord_index_ops():
    """sel / loc / sortby through an Int64-backed index."""
    idx = pd.Index([2, 1, 3], dtype="Int64", name="i")
    da = xr.DataArray([20, 10, 30], dims="i", coords={"i": idx}, name="v")
    _ = da.sel(i=2)
    _ = da.sel(i=[1, 3])
    _ = da.loc[dict(i=1)]
    _ = da.sortby("i")

def test_Int64_groupby_dropna():
    """groupby and dropna over an Int64 coordinate."""
    idx = pd.Index([1, 2, 1, 3], dtype="Int64", name="i")
    da = xr.DataArray([1.0, 2.0, 3.0, np.nan], dims="i", coords={"i": idx}, name="v")
    _ = da.groupby("i").sum()
    _ = da.dropna("i")

def test_Int64_values_sortby_self():
    """sortby(self) on Int64 values containing NA."""
    s = pd.Series([2, 1, None], dtype="Int64", name="v")
    da = s.to_xarray()
    _check_sortby_self_values(da)

def test_int64_pyarrow_values_where():
    """where() on int64[pyarrow] values."""
    s = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v")
    da = s.to_xarray()
    _ = da.where(da > 1)

def test_int64_pyarrow_values_concat():
    """concat along a new dimension with int64[pyarrow] values."""
    s = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v")
    da = s.to_xarray()
    _ = xr.concat([da, da], dim="rep")

def test_int64_pyarrow_coord_align():
    """Outer align of an int64[pyarrow] coordinate against a plain int list."""
    idx = pd.Index([1, 2, 3], dtype="int64[pyarrow]", name="i_arrow")
    da1 = xr.DataArray([5, 6, 7], dims="i_arrow", coords={"i_arrow": idx}, name="v1")
    da2 = xr.DataArray([10, 20], dims="i_arrow", coords={"i_arrow": [2, 4]}, name="v2")
    _ = xr.align(da1, da2, join="outer")

def test_int64_pyarrow_values_where_null():
    """where() with a notnull() mask on int64[pyarrow] values containing NA."""
    s = pd.Series([1, None, 3], dtype="int64[pyarrow]", name="v")
    da = s.to_xarray()
    mask = da.notnull()
    _ = da.where(mask)

def test_int64_pyarrow_values_values():
    """.values access on int64[pyarrow] data."""
    s = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v")
    da = s.to_xarray()
    _ = da.values

def test_int64_pyarrow_arith_int64():
    """int64[pyarrow] + int64: both orders should give float64 with NaN for NA."""
    s = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v")
    da_nullable = s.to_xarray()
    da_int = xr.DataArray(
        np.array([10, 20, 30], dtype="int64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name="base",
    )
    out1 = da_nullable + da_int
    out2 = da_int + da_nullable
    expected = xr.DataArray(
        np.array([11.0, 22.0, np.nan], dtype="float64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name=None,
    )
    xrt.assert_allclose(out1, expected)
    xrt.assert_allclose(out2, expected)

def test_int64_pyarrow_arith_float64():
    """int64[pyarrow] + float64: both orders should give float64 with NaN for NA."""
    s = pd.Series([1, 2, None], dtype="int64[pyarrow]", name="v")
    da_nullable = s.to_xarray()
    da_float = xr.DataArray(
        np.array([0.5, 1.5, 2.5], dtype="float64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name="f",
    )
    out1 = da_nullable + da_float
    out2 = da_float + da_nullable
    expected = xr.DataArray(
        np.array([1.5, 3.5, np.nan], dtype="float64"),
        dims="index",
        coords={"index": [0, 1, 2]},
        name=None,
    )
    xrt.assert_allclose(out1, expected)
    xrt.assert_allclose(out2, expected)

def test_int64_pyarrow_coord_index_ops():
    """sel / loc / sortby through an int64[pyarrow]-backed index."""
    idx = pd.Index([2, 1, 3], dtype="int64[pyarrow]", name="i_arrow")
    da = xr.DataArray([20, 10, 30], dims="i_arrow", coords={"i_arrow": idx}, name="v")
    _ = da.sel(i_arrow=2)
    _ = da.sel(i_arrow=[1, 3])
    _ = da.loc[dict(i_arrow=1)]
    _ = da.sortby("i_arrow")

def test_int64_pyarrow_groupby_dropna():
    """groupby and dropna over an int64[pyarrow] coordinate."""
    idx = pd.Index([1, 2, 1, 3], dtype="int64[pyarrow]", name="i_arrow")
    da = xr.DataArray([1.0, 2.0, 3.0, np.nan], dims="i_arrow", coords={"i_arrow": idx}, name="v")
    _ = da.groupby("i_arrow").sum()
    _ = da.dropna("i_arrow")

def test_int64_pyarrow_values_sortby_self():
    """sortby(self) on int64[pyarrow] values containing NA."""
    s = pd.Series([2, 1, None], dtype="int64[pyarrow]", name="v")
    da = s.to_xarray()
    _check_sortby_self_values(da)
# =============================================================================
# Categorical as values and coordinate
# =============================================================================
def test_categorical_values_where():
cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
da = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val")
_ = da.where(da != "a")
def test_categorical_values_concat():
cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
da = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val")
_ = xr.concat([da, da], dim="rep")
def test_categorical_values_align_vs_object():
cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
da1 = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val")
da2 = xr.DataArray(
np.array(["a", "c", "d"], dtype=object),
dims="x",
coords={"x": [0, 1, 2]},
name="obj_val",
)
_ = xr.align(da1, da2, join="outer")
def test_categorical_values_where_null():
cat = pd.Categorical(["a", None, "b", "c"], categories=["a", "b", "c"])
da = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2, 3]}, name="cat_val_na")
mask = ~da.isnull()
_ = da.where(mask)
def test_categorical_values_values():
cat = pd.Categorical(["a", "b", "a", "c"], categories=["a", "b", "c"])
da = xr.DataArray(cat, dims="x", name="cat_val")
_ = da.values
def test_categorical_values_dropna():
cat = pd.Categorical(["a", None, "b"], categories=["a", "b"])
da = xr.DataArray(cat, dims="x", coords={"x": [0, 1, 2]}, name="cat_val")
_ = da.dropna("x")
def test_categorical_values_sortby_self():
    """sortby(self) on Categorical-backed values containing a missing entry."""
    values = pd.Categorical(["b", "a", None], categories=["a", "b"])
    arr = xr.DataArray(values, dims="x", coords={"x": [0, 1, 2]}, name="cat_val")
    _check_sortby_self_values(arr)
def test_categorical_coord_where():
    """where() on a DataArray whose coordinate is a CategoricalIndex."""
    labels = pd.CategoricalIndex(["A", "B", "C"], categories=["A", "B", "C"], name="lab")
    arr = xr.DataArray([1, 2, 3], dims="lab", coords={"lab": labels}, name="v")
    _ = arr.where(arr > 1)
def test_categorical_coord_concat():
    """concat along a new dimension with a CategoricalIndex coordinate."""
    labels = pd.CategoricalIndex(["A", "B", "C"], categories=["A", "B", "C"], name="lab")
    arr = xr.DataArray([1, 2, 3], dims="lab", coords={"lab": labels}, name="v")
    _ = xr.concat([arr, arr], dim="rep")
def test_categorical_coord_align_vs_object_index():
    """Inner align between an object-dtype index and a CategoricalIndex coordinate."""
    obj_labels = pd.Index(["A", "B"], dtype="object", name="lab")
    left = xr.Dataset({"v": ("lab", [1, 2])}, coords={"lab": obj_labels})
    cat_labels = pd.CategoricalIndex(["B", "C"], categories=["A", "B", "C"], name="lab")
    right = xr.Dataset({"v": ("lab", [3, 4])}, coords={"lab": cat_labels})
    _ = xr.align(left, right, join="inner")
def test_categorical_coord_align_with_null():
    """Outer align of two CategoricalIndex coordinates, one containing a null label."""
    labels_with_null = pd.CategoricalIndex(
        ["A", None, "C"], categories=["A", "B", "C"], name="lab"
    )
    left = xr.DataArray(
        [1, 2, 3], dims="lab", coords={"lab": labels_with_null}, name="v1"
    )
    labels_full = pd.CategoricalIndex(["A", "B"], categories=["A", "B", "C"], name="lab")
    right = xr.DataArray([10, 20], dims="lab", coords={"lab": labels_full}, name="v2")
    _ = xr.align(left, right, join="outer")
def test_categorical_coord_index_ops():
    """sel / loc / sortby operations on a CategoricalIndex coordinate."""
    labels = pd.CategoricalIndex(["C", "A", "B"], categories=["A", "B", "C"], name="lab")
    arr = xr.DataArray([3, 1, 2], dims="lab", coords={"lab": labels}, name="v")
    _ = arr.sel(lab="B")
    _ = arr.sel(lab=["A", "C"])
    _ = arr.loc[{"lab": "A"}]
    _ = arr.sortby("lab")
def test_categorical_coord_groupby_dropna():
    """groupby-sum and dropna on a CategoricalIndex coordinate with a null label."""
    labels = pd.CategoricalIndex(
        ["A", "B", None, "A"], categories=["A", "B", "C"], name="lab"
    )
    arr = xr.DataArray(
        [1.0, 2.0, 3.0, 4.0], dims="lab", coords={"lab": labels}, name="v"
    )
    _ = arr.groupby("lab").sum()
    _ = arr.dropna("lab")
# =============================================================================
# Main
# =============================================================================
def main():
    """Print environment versions, then run every dtype edge-case test via run_test,
    reporting one PASS/FAIL line per test."""
    print_header()
    # Registry of (display name, test callable) pairs, grouped by dtype family.
    tests = [
        # NumPy StringDType
        ("NumPy StringDType values: where", test_numpy_stringdtype_values_where),
        ("NumPy StringDType values: concat", test_numpy_stringdtype_values_concat),
        ("NumPy StringDType coord vs string: align", test_numpy_stringdtype_coord_align),
        ("NumPy StringDType values: where with null", test_numpy_stringdtype_values_where_null),
        ("NumPy StringDType values: .values", test_numpy_stringdtype_values_values),
        ("NumPy StringDType values: sortby(self)", test_numpy_stringdtype_values_sortby_self),
        # string[pyarrow]
        ("string[pyarrow] values: where", test_string_pyarrow_values_where),
        ("string[pyarrow] values: concat", test_string_pyarrow_values_concat),
        ("string[pyarrow] values: align self vs slice", test_string_pyarrow_values_align),
        ("string[pyarrow] values: where with null", test_string_pyarrow_values_where_null),
        ("string[pyarrow] values: .values", test_string_pyarrow_values_values),
        ("string[pyarrow] values: sortby(self)", test_string_pyarrow_values_sortby_self),
        ("string[pyarrow] coord: where", test_string_pyarrow_coord_where),
        ("string[pyarrow] coord: concat", test_string_pyarrow_coord_concat),
        ("string[pyarrow] coord vs object: align", test_string_pyarrow_coord_align),
        ("string[pyarrow] coord vs object with null: align", test_string_pyarrow_coord_align_with_null),
        ("string[pyarrow] coord: sel/loc/sortby", test_string_pyarrow_coord_index_ops),
        ("string[pyarrow] coord: groupby/dropna", test_string_pyarrow_groupby_dropna),
        # date32[pyarrow]
        ("date32[pyarrow] coord: where", test_date32_coord_where),
        ("date32[pyarrow] coord: concat", test_date32_coord_concat),
        ("date32[pyarrow] coord vs datetime64: align", test_date32_coord_align_vs_datetime64),
        ("date32[pyarrow] coord: where with null", test_date32_coord_where_null),
        ("date32[pyarrow] coord: .values", test_date32_coord_values),
        ("date32[pyarrow] coord: sel/sortby", test_date32_coord_index_ops),
        ("date32[pyarrow] coord: groupby/dropna", test_date32_groupby_dropna),
        # Int64 / int64[pyarrow]
        ("Int64 values: where", test_int64_nullable_values_where),
        ("Int64 values: concat", test_int64_nullable_values_concat),
        ("Int64 coord vs int64: align", test_int64_nullable_coord_align),
        ("Int64 values: where with null", test_int64_nullable_values_where_null),
        ("Int64 values: .values", test_int64_nullable_values_values),
        ("Int64 values: + int64", test_int64_nullable_arith_int64),
        ("Int64 values: + float64", test_int64_nullable_arith_float64),
        ("Int64 coord: sel/loc/sortby", test_Int64_coord_index_ops),
        ("Int64 coord: groupby/dropna", test_Int64_groupby_dropna),
        ("Int64 values: sortby(self)", test_Int64_values_sortby_self),
        ("int64[pyarrow] values: where", test_int64_pyarrow_values_where),
        ("int64[pyarrow] values: concat", test_int64_pyarrow_values_concat),
        ("int64[pyarrow] coord vs int64: align", test_int64_pyarrow_coord_align),
        ("int64[pyarrow] values: where with null", test_int64_pyarrow_values_where_null),
        ("int64[pyarrow] values: .values", test_int64_pyarrow_values_values),
        ("int64[pyarrow] values: + int64", test_int64_pyarrow_arith_int64),
        ("int64[pyarrow] values: + float64", test_int64_pyarrow_arith_float64),
        ("int64[pyarrow] coord: sel/loc/sortby", test_int64_pyarrow_coord_index_ops),
        ("int64[pyarrow] coord: groupby/dropna", test_int64_pyarrow_groupby_dropna),
        ("int64[pyarrow] values: sortby(self)", test_int64_pyarrow_values_sortby_self),
        # Categorical
        ("Categorical values: where", test_categorical_values_where),
        ("Categorical values: concat", test_categorical_values_concat),
        ("Categorical values vs object: align", test_categorical_values_align_vs_object),
        ("Categorical values: where with null", test_categorical_values_where_null),
        ("Categorical values: .values", test_categorical_values_values),
        ("Categorical values: dropna", test_categorical_values_dropna),
        ("Categorical values: sortby(self)", test_categorical_values_sortby_self),
        ("CategoricalIndex coord: where", test_categorical_coord_where),
        ("CategoricalIndex coord: concat", test_categorical_coord_concat),
        ("object vs CategoricalIndex coord: align", test_categorical_coord_align_vs_object_index),
        ("CategoricalIndex coord with null: align", test_categorical_coord_align_with_null),
        ("CategoricalIndex coord: sel/loc/sortby", test_categorical_coord_index_ops),
        ("CategoricalIndex coord: groupby/dropna", test_categorical_coord_groupby_dropna),
    ]
    # run_test catches each exception and prints a single PASS/FAIL line,
    # so one failing test never aborts the rest of the suite.
    for name, func in tests:
        run_test(name, func)
# Script entry point: run the full mini test-suite when executed directly.
if __name__ == "__main__":
    main()
And the output:
FAIL | NumPy StringDType values: where (DTypePromotionError: The DType <class 'numpy.dtypes.StringDType'> could not be promoted by <class 'numpy.dtypes._PyFloatDType'>. This means that no common DType exists for the given inputs. For example they cannot be stored in a single array unless the dtype is `object`. The full list of DTypes is: (<class 'numpy.dtypes.StringDType'>, <class 'numpy.dtypes._PyFloatDType'>))
PASS | NumPy StringDType values: concat
PASS | NumPy StringDType coord vs string: align
FAIL | NumPy StringDType values: where with null (TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtypes.StringDType'>.)
PASS | NumPy StringDType values: .values
PASS | NumPy StringDType values: sortby(self)
FAIL | string[pyarrow] values: where (TypeError: boolean value of NA is ambiguous)
PASS | string[pyarrow] values: concat
PASS | string[pyarrow] values: align self vs slice
PASS | string[pyarrow] values: where with null
PASS | string[pyarrow] values: .values
FAIL | string[pyarrow] values: sortby(self) (TypeError: boolean value of NA is ambiguous)
PASS | string[pyarrow] coord: where
PASS | string[pyarrow] coord: concat
PASS | string[pyarrow] coord vs object: align
PASS | string[pyarrow] coord vs object with null: align
PASS | string[pyarrow] coord: sel/loc/sortby
PASS | string[pyarrow] coord: groupby/dropna
PASS | date32[pyarrow] coord: where
PASS | date32[pyarrow] coord: concat
FAIL | date32[pyarrow] coord vs datetime64: align (TypeError: Cannot interpret 'date32[day][pyarrow]' as a data type)
PASS | date32[pyarrow] coord: where with null
PASS | date32[pyarrow] coord: .values
PASS | date32[pyarrow] coord: sel/sortby
PASS | date32[pyarrow] coord: groupby/dropna
FAIL | Int64 values: where (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
PASS | Int64 values: concat
FAIL | Int64 coord vs int64: align (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
FAIL | Int64 values: where with null (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
PASS | Int64 values: .values
PASS | Int64 values: + int64
PASS | Int64 values: + float64
PASS | Int64 coord: sel/loc/sortby
PASS | Int64 coord: groupby/dropna
PASS | Int64 values: sortby(self)
FAIL | int64[pyarrow] values: where (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
FAIL | int64[pyarrow] values: concat (IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices)
FAIL | int64[pyarrow] coord vs int64: align (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
FAIL | int64[pyarrow] values: where with null (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
PASS | int64[pyarrow] values: .values
PASS | int64[pyarrow] values: + int64
PASS | int64[pyarrow] values: + float64
PASS | int64[pyarrow] coord: sel/loc/sortby
PASS | int64[pyarrow] coord: groupby/dropna
FAIL | int64[pyarrow] values: sortby(self) (TypeError: only integer scalar arrays can be converted to a scalar index)
FAIL | Categorical values: where (TypeError: Cannot interpret 'CategoricalDtype(categories=['a', 'b', 'c'], ordered=False, categories_dtype=object)' as a data type)
PASS | Categorical values: concat
PASS | Categorical values vs object: align
FAIL | Categorical values: where with null (TypeError: Cannot interpret 'CategoricalDtype(categories=['a', 'b', 'c'], ordered=False, categories_dtype=object)' as a data type)
PASS | Categorical values: .values
PASS | Categorical values: dropna
FAIL | Categorical values: sortby(self) (TypeError: '<' not supported between instances of 'float' and 'str')
PASS | CategoricalIndex coord: where
PASS | CategoricalIndex coord: concat
FAIL | object vs CategoricalIndex coord: align (TypeError: Cannot interpret 'CategoricalDtype(categories=['A', 'B', 'C'], ordered=False, categories_dtype=object)' as a data type)
FAIL | CategoricalIndex coord with null: align (TypeError: Cannot interpret 'CategoricalDtype(categories=['A', 'B', 'C'], ordered=False, categories_dtype=object)' as a data type)
PASS | CategoricalIndex coord: sel/loc/sortby
PASS | CategoricalIndex coord: groupby/dropna
I will have a look at this, thanks so much!
A few questions:
- Many failures of pyarrow string arrays are
`TypeError: boolean value of NA is ambiguous` — but this is fundamental to pandas. We could start overriding this, but I think the warning is helpful, no? Wouldn't you want to know that you have NA values (which have ambiguous boolean interpretation)? Personally, it feels like this is not an error and I would want to know that this is the case. - For the categorical sortby, the categories aren't ordered so I probably wouldn't expect that to work as the test is written. However, if you were to order them I think we could implement something on the NEP-18 wrapper for
`lexsort`. Great find, hopefully it is really that simple! - The
`CategoricalIndex` stuff is unsurprising but can be addressed by probably just applying the fix #10423 there. Thanks! (@dcherian maybe you have other spots you're aware of where index behavior can get hairy?) - The
`StringDtype` stuff I can look at, but looks unrelated to extension array stuff.
Thanks for the issue! Please chime in on point 1. anybody :)
@alippai it would be great if you could contribute a PR with these tests (xfailed if needed), suitably modified by @ilan-gold's feedback above.
@dcherian I've already got what mostly looks like a fix but @alippai you're welcome (encouraged? love to have more contributors) to take if over the finish line.
I'm currently doing the following:
diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py
index 40334d58..80f9a992 100644
--- a/xarray/core/dtypes.py
+++ b/xarray/core/dtypes.py
@@ -50,6 +50,10 @@ PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
(np.number, np.character), # numpy promotes to character
(np.bool_, np.character), # numpy promotes to character
(np.bytes_, np.str_), # numpy promotes to unicode
+ (
+ np.object_,
+ np.object_,
    ), # default case, if only object is found, it should be promoted to object
)
T_dtype = TypeVar("T_dtype", np.dtype, ExtensionDtype)
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index c4f09144..5b88da08 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -12,6 +12,7 @@ import pandas as pd
from xarray.core import formatting, nputils, utils
from xarray.core.coordinate_transform import CoordinateTransform
+from xarray.core.dtypes import result_type
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.indexing import (
CoordinateTransformIndexingAdapter,
@@ -768,12 +769,10 @@ class PandasIndex(Index):
if not indexes:
coord_dtype = None
+ elif len(indexes_coord_dtypes := {idx.coord_dtype for idx in indexes}) == 1:
+ coord_dtype = next(iter(indexes_coord_dtypes))
else:
- indexes_coord_dtypes = {idx.coord_dtype for idx in indexes}
- if len(indexes_coord_dtypes) == 1:
- coord_dtype = next(iter(indexes_coord_dtypes))
- else:
- coord_dtype = np.result_type(*indexes_coord_dtypes)
+ coord_dtype = result_type(*indexes_coord_dtypes)
return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype)
@@ -895,8 +894,12 @@ class PandasIndex(Index):
else:
# how = "inner"
index = self.index.intersection(other.index)
-
- coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype)
# If pandas extension array machinery already took care of this promotion for us, go with that result
+ coord_dtype = (
+ index.dtype
+ if utils.is_allowed_extension_array_dtype(index.dtype)
+ else result_type(self.coord_dtype, other.coord_dtype)
+ )
return type(self)(index, self.dim, coord_dtype=coord_dtype)
def reindex_like(
But it seems to have broken a string dtype test so working through that, but I probably won't finish today - Presumably the type promotion is wrong (if you see an object, maybe we promote down with strings in some cases?)
Update! There is some type promotion weirdness around strings that our internal result_type does that np.result_type does not do - see pytest "xarray/tests/test_dataset.py::TestDataset::test_align_str_dtype". Logging off for the day, but feel free to pick up.
I'll have limited time until the holidays, but I'll try to chip in.
- I agree this might not be interesting
- Cool
- Looks good indeed
- Might be not interesting
The two different align errors (pyarrow and the Categorical) maybe something worth checking.
Looks better after upgrading to 2025.12.0, 5 tests fixed:
xarray version: 2025.12.0
pandas version: 2.3.3
numpy version: 2.3.5
pyarrow version: 22.0.0
FAIL | NumPy StringDType values: where (DTypePromotionError: The DType <class 'numpy.dtypes.StringDType'> could not be promoted by <class 'numpy.dtypes._PyFloatDType'>. This means that no common DType exists for the given inputs. For example they cannot be stored in a single array unless the dtype is `object`. The full list of DTypes is: (<class 'numpy.dtypes.StringDType'>, <class 'numpy.dtypes._PyFloatDType'>))
PASS | NumPy StringDType values: concat
PASS | NumPy StringDType coord vs string: align
FAIL | NumPy StringDType values: where with null (TypeError: dtype argument must be a NumPy dtype, but it is a <class 'numpy.dtypes.StringDType'>.)
PASS | NumPy StringDType values: .values
PASS | NumPy StringDType values: sortby(self)
FAIL | string[pyarrow] values: where (TypeError: boolean value of NA is ambiguous)
PASS | string[pyarrow] values: concat
PASS | string[pyarrow] values: align self vs slice
PASS | string[pyarrow] values: where with null
PASS | string[pyarrow] values: .values
FAIL | string[pyarrow] values: sortby(self) (TypeError: boolean value of NA is ambiguous)
PASS | string[pyarrow] coord: where
PASS | string[pyarrow] coord: concat
PASS | string[pyarrow] coord vs object: align
PASS | string[pyarrow] coord vs object with null: align
PASS | string[pyarrow] coord: sel/loc/sortby
PASS | string[pyarrow] coord: groupby/dropna
PASS | date32[pyarrow] coord: where
PASS | date32[pyarrow] coord: concat
FAIL | date32[pyarrow] coord vs datetime64: align (TypeError: Cannot interpret 'date32[day][pyarrow]' as a data type)
PASS | date32[pyarrow] coord: where with null
PASS | date32[pyarrow] coord: .values
PASS | date32[pyarrow] coord: sel/sortby
PASS | date32[pyarrow] coord: groupby/dropna
FAIL | Int64 values: where (TypeError: boolean value of NA is ambiguous)
PASS | Int64 values: concat
FAIL | Int64 coord vs int64: align (TypeError: Cannot interpret 'Int64Dtype()' as a data type)
PASS | Int64 values: where with null
PASS | Int64 values: .values
PASS | Int64 values: + int64
PASS | Int64 values: + float64
PASS | Int64 coord: sel/loc/sortby
PASS | Int64 coord: groupby/dropna
PASS | Int64 values: sortby(self)
FAIL | int64[pyarrow] values: where (TypeError: boolean value of NA is ambiguous)
FAIL | int64[pyarrow] values: concat (IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices)
FAIL | int64[pyarrow] coord vs int64: align (TypeError: Cannot interpret 'int64[pyarrow]' as a data type)
PASS | int64[pyarrow] values: where with null
PASS | int64[pyarrow] values: .values
PASS | int64[pyarrow] values: + int64
PASS | int64[pyarrow] values: + float64
PASS | int64[pyarrow] coord: sel/loc/sortby
PASS | int64[pyarrow] coord: groupby/dropna
PASS | int64[pyarrow] values: sortby(self)
PASS | Categorical values: where
PASS | Categorical values: concat
PASS | Categorical values vs object: align
PASS | Categorical values: where with null
PASS | Categorical values: .values
PASS | Categorical values: dropna
FAIL | Categorical values: sortby(self) (TypeError: '<' not supported between instances of 'float' and 'str')
PASS | CategoricalIndex coord: where
PASS | CategoricalIndex coord: concat
FAIL | object vs CategoricalIndex coord: align (TypeError: Cannot interpret 'CategoricalDtype(categories=['A', 'B', 'C'], ordered=False, categories_dtype=object)' as a data type)
FAIL | CategoricalIndex coord with null: align (TypeError: Cannot interpret 'CategoricalDtype(categories=['A', 'B', 'C'], ordered=False, categories_dtype=object)' as a data type)
PASS | CategoricalIndex coord: sel/loc/sortby
PASS | CategoricalIndex coord: groupby/dropna