xarray

Fix NA semantics for numpy.StringDType

Open alippai opened this issue 3 months ago • 0 comments

What happened?

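Arrays using `numpy.dtypes.StringDType(na_object=np.nan)` or `StringDType(na_object=None)` are only partially supported, or not supported at all, by xarray's missing-value handling: in the example below, `isnull`, `count`, `dropna` and `where` all raise `TypeError` for both `na_object` values.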

What did you expect to happen?

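The same behavior as for object-dtype arrays of Python strings: the `na_object` value should be treated as missing by `isnull`, `count`, `dropna` and `where`.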

Minimal Complete Verifiable Example

# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "xarray[complete]@git+https://github.com/pydata/xarray.git@main",
# ]
# ///
#
# This script automatically imports the development branch of xarray to check for issues.
# Please delete this header if you have _not_ tested this script with `uv run`!

"""
Check how xarray handles NumPy StringDType with different na_object values.

It compares:
  - object-typed DataArray / Dataset (baseline)
  - StringDType-typed DataArray / Dataset (under test)

For each case it runs:
  - .isnull()
  - .count()
  - .dropna(dim="x")
  - .where(~isnull, other="filled")

and records success / failure (with the error message, including na_object).
"""

# Must follow only the docstring and comments: a `from __future__` import
# placed after any other statement is a SyntaxError.
from __future__ import annotations

from textwrap import shorten

import numpy as np
import xarray as xr
from numpy.dtypes import StringDType

# Environment report requested by the xarray issue template.
xr.show_versions()

# na_object values passed to StringDType(na_object=...); main() also places
# each one in the middle element of its test array.
NA_OBJECTS = [np.nan, None]
# Operation names, in the order the results table prints them.
OPS = ["isnull", "count", "dropna", "where"]


def _short(obj, width=40):
    return shorten(repr(obj), width=width, placeholder="…")


def run_op(kind: str, da_or_var: xr.DataArray):
    """Run the four NA-related operations and record each outcome.

    Parameters
    ----------
    kind
        Label for the case being tested. It is not used in the body; it is
        kept so existing callers keep working.
    da_or_var
        The DataArray under test (possibly a variable taken from a Dataset).
        It must have a dimension named ``"x"`` because ``dropna`` uses it.

    Returns
    -------
    dict[str, tuple[bool, str]]
        Maps each op name (``"isnull"``, ``"count"``, ``"dropna"``,
        ``"where"``) to ``(ok, detail)``. On success ``detail`` is a short
        repr of the result's ``.values``. On failure it is
        ``"<ExcType>: <message>"``. Each op is tried on its own, so one
        failure does not skip the others.
    """
    # Each check: (op name, label for the success detail, operation).
    checks = (
        ("isnull", "mask", lambda v: v.isnull()),
        ("count", "value", lambda v: v.count()),
        ("dropna", "values", lambda v: v.dropna(dim="x")),
        # Keep non-missing entries; replace missing ones with "filled".
        ("where", "values", lambda v: v.where(~v.isnull(), other="filled")),
    )

    results: dict[str, tuple[bool, str]] = {}
    for name, label, operation in checks:
        try:
            res = operation(da_or_var)
            results[name] = (True, f"{label}={_short(res.values)}")
        except Exception as e:  # noqa: BLE001 - record every failure, keep going
            # Use str(e), not repr(e): repr repeats the exception type name
            # and spends the truncation budget on it instead of the message.
            message = shorten(str(e), width=40, placeholder="…")
            results[name] = (False, f"{type(e).__name__}: {message}")
    return results


def print_case_header(title: str, arr: np.ndarray, dtype, na_obj):
    """Print a blank line, a ``=== title ===`` banner and the case details.

    The details are the dtype (its repr and its Python type), the
    ``na_object`` used and the repr of the values.
    """
    header_lines = (
        "",
        f"=== {title} ===",
        f"  storage dtype: {dtype!r} (type={type(dtype)})",
        f"  na_object: {na_obj!r}",
        f"  values: {arr!r}",
    )
    print("\n".join(header_lines))


def print_results_table(results: dict[str, tuple[bool, str]]):
    """Print an ASCII table with one row per operation listed in ``OPS``.

    Columns are the op name, ``OK``/``FAIL`` and the detail text, which is
    truncated to 60 characters.

    Raises ``KeyError`` if *results* has no entry for an op in ``OPS``.
    """
    col_op = max(len(name) for name in ("op", *OPS))
    col_ok = len("status")
    col_detail = 60

    print(f"  {'op'.ljust(col_op)}  {'status'.ljust(col_ok)}  detail")
    print("  " + "-" * col_op + "  " + "-" * col_ok + "  " + "-" * col_detail)

    for op in OPS:
        ok, detail = results[op]
        status = "OK" if ok else "FAIL"
        # detail is already text: truncate it directly. Going through repr()
        # would wrap it in quotes and escape its contents.
        shown = shorten(detail, width=col_detail, placeholder="…")
        print(f"  {op.ljust(col_op)}  {status.ljust(col_ok)}  {shown}")


def _run_case(
    *,
    title: str,
    scope: str,
    storage: str,
    na_label: str,
    var: xr.DataArray,
    na_obj,
    kind: str,
) -> tuple[str, str, str, int, int]:
    """Print one test case and return its row for the summary matrix.

    Prints the case header, runs every op through ``run_op``, prints the
    per-op table, and returns ``(scope, storage, na_label, #ok, #fail)``.
    """
    print_case_header(title, var.values, var.dtype, na_obj)
    results = run_op(kind, var)
    print_results_table(results)
    n_ok = sum(ok for ok, _detail in results.values())
    return (scope, storage, na_label, n_ok, len(results) - n_ok)


def _print_summary(summary_rows: list[tuple[str, str, str, int, int]]) -> None:
    """Print the final summary matrix, one line per test case."""
    print("\n==================== SUMMARY ====================")
    col_scope = max(len("scope"), max(len(r[0]) for r in summary_rows))
    col_storage = max(len("storage"), max(len(r[1]) for r in summary_rows))
    col_naobj = max(len("na_object"), max(len(r[2]) for r in summary_rows))

    header = (
        f"{'scope'.ljust(col_scope)}  "
        f"{'storage'.ljust(col_storage)}  "
        f"{'na_object'.ljust(col_naobj)}  "
        f"{'#OK':>3}  "
        f"{'#FAIL':>5}"
    )
    print(header)
    print("-" * len(header))
    for scope, storage, na_obj_repr, ok_count, fail_count in summary_rows:
        print(
            f"{scope.ljust(col_scope)}  "
            f"{storage.ljust(col_storage)}  "
            f"{na_obj_repr.ljust(col_naobj)}  "
            f"{ok_count:>3}  "
            f"{fail_count:>5}"
        )


def main():
    """Run every case, printing per-case tables and a final summary matrix.

    Cases: an object-dtype baseline, then ``StringDType(na_object=...)`` for
    each value in ``NA_OBJECTS``. Each is checked as a standalone DataArray
    and as a variable taken from a Dataset.
    """
    print("NumPy version:", np.__version__)
    print("xarray version:", xr.__version__)
    print("============================================================")

    summary_rows = []

    # 1) Baseline: object dtype holding np.nan as the missing value
    #    (no custom na_object).
    arr_obj = np.array(["a", np.nan, "c"], dtype=object)

    da_obj = xr.DataArray(arr_obj, dims="x", name="str_obj")
    summary_rows.append(
        _run_case(
            title="DataArray – object dtype (baseline)",
            scope="DataArray",
            storage="object",
            na_label="-",
            var=da_obj,
            na_obj=None,
            kind="DataArray-object",
        )
    )

    # NOTE(review): ds["str"] yields a DataArray, so the "Dataset" cases run
    # DataArray methods on a Dataset variable rather than calling
    # Dataset.isnull()/count()/dropna()/where() themselves.
    ds_obj = xr.Dataset({"str": ("x", arr_obj)})
    summary_rows.append(
        _run_case(
            title="Dataset['str'] – object dtype (baseline)",
            scope="Dataset['str']",
            storage="object",
            na_label="-",
            var=ds_obj["str"],
            na_obj=None,
            kind="Dataset-object",
        )
    )

    # 2) StringDType with different na_object values.
    for na_obj in NA_OBJECTS:
        dt = StringDType(na_object=na_obj)
        arr_str = np.array(["a", na_obj, "c"], dtype=dt)

        da_str = xr.DataArray(arr_str, dims="x", name=f"str_StringDType_{type(na_obj).__name__}")
        summary_rows.append(
            _run_case(
                title=f"DataArray – StringDType(na_object={na_obj!r})",
                scope="DataArray",
                storage="StringDType",
                na_label=repr(na_obj),
                var=da_str,
                na_obj=na_obj,
                kind="DataArray-StringDType",
            )
        )

        ds_str = xr.Dataset({"str": ("x", arr_str)})
        summary_rows.append(
            _run_case(
                title=f"Dataset['str'] – StringDType(na_object={na_obj!r})",
                scope="Dataset['str']",
                storage="StringDType",
                na_label=repr(na_obj),
                var=ds_str["str"],
                na_obj=na_obj,
                kind="Dataset-StringDType",
            )
        )

    _print_summary(summary_rows)


if __name__ == "__main__":
    main()

Steps to reproduce

No response

MVCE confirmation

  • [x] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • [x] Complete example — the example is self-contained, including all data and the text of any traceback.
  • [x] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • [x] New issue — a search of GitHub Issues suggests this is not a duplicate.
  • [x] Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

NumPy version: 2.3.5                                                                                                                                                                                                                                                                                                  
xarray version: 2025.11.0
============================================================

=== DataArray – object dtype (baseline) ===
  storage dtype: dtype('O') (type=<class 'numpy.dtypes.ObjectDType'>)
  na_object: None
  values: array(['a', nan, 'c'], dtype=object)
  op      status  detail
  ------  ------  ------------------------------------------------------------
  isnull  OK      'mask=array([False, True, False])'
  count   OK      'value=array(2)'
  dropna  OK      "values=array(['a', 'c'], dtype=object)"
  where   OK      "values=array(['a', 'filled', 'c'],…"

=== Dataset['str'] – object dtype (baseline) ===
  storage dtype: dtype('O') (type=<class 'numpy.dtypes.ObjectDType'>)
  na_object: None
  values: array(['a', nan, 'c'], dtype=object)
  op      status  detail
  ------  ------  ------------------------------------------------------------
  isnull  OK      'mask=array([False, True, False])'
  count   OK      'value=array(2)'
  dropna  OK      "values=array(['a', 'c'], dtype=object)"
  where   OK      "values=array(['a', 'filled', 'c'],…"

=== DataArray – StringDType(na_object=nan) ===
  storage dtype: StringDType(na_object=nan) (type=<class 'numpy.dtypes.StringDType'>)
  na_object: nan
  values: array(['a', nan, 'c'], dtype=StringDType(na_object=nan))
  op      status  detail
  ------  ------  ------------------------------------------------------------
  isnull  FAIL    'TypeError: TypeError("dtype argument must be a…'
  count   FAIL    'TypeError: TypeError("dtype argument must be a…'
  dropna  FAIL    'TypeError: TypeError("dtype argument must be a…'
  where   FAIL    'TypeError: TypeError("dtype argument must be a…'

=== Dataset['str'] – StringDType(na_object=nan) ===
  storage dtype: StringDType(na_object=nan) (type=<class 'numpy.dtypes.StringDType'>)
  na_object: nan
  values: array(['a', nan, 'c'], dtype=StringDType(na_object=nan))
  op      status  detail
  ------  ------  ------------------------------------------------------------
  isnull  FAIL    'TypeError: TypeError("dtype argument must be a…'
  count   FAIL    'TypeError: TypeError("dtype argument must be a…'
  dropna  FAIL    'TypeError: TypeError("dtype argument must be a…'
  where   FAIL    'TypeError: TypeError("dtype argument must be a…'

=== DataArray – StringDType(na_object=None) ===
  storage dtype: StringDType(na_object=None) (type=<class 'numpy.dtypes.StringDType'>)
  na_object: None
  values: array(['a', None, 'c'], dtype=StringDType(na_object=None))
  op      status  detail
  ------  ------  ------------------------------------------------------------
  isnull  FAIL    'TypeError: TypeError("dtype argument must be a…'
  count   FAIL    'TypeError: TypeError("dtype argument must be a…'
  dropna  FAIL    'TypeError: TypeError("dtype argument must be a…'
  where   FAIL    'TypeError: TypeError("dtype argument must be a…'

=== Dataset['str'] – StringDType(na_object=None) ===
  storage dtype: StringDType(na_object=None) (type=<class 'numpy.dtypes.StringDType'>)
  na_object: None
  values: array(['a', None, 'c'], dtype=StringDType(na_object=None))
  op      status  detail
  ------  ------  ------------------------------------------------------------
  isnull  FAIL    'TypeError: TypeError("dtype argument must be a…'
  count   FAIL    'TypeError: TypeError("dtype argument must be a…'
  dropna  FAIL    'TypeError: TypeError("dtype argument must be a…'
  where   FAIL    'TypeError: TypeError("dtype argument must be a…'

==================== SUMMARY ====================
scope           storage      na_object  #OK  #FAIL
--------------------------------------------------
DataArray       object       -            4      0
Dataset['str']  object       -            4      0
DataArray       StringDType  nan          0      4
Dataset['str']  StringDType  nan          0      4
DataArray       StringDType  None         0      4
Dataset['str']  StringDType  None         0      4

Anything else we need to know?

No response

Environment

NumPy version: 2.3.5 xarray version: 2025.11.0

alippai avatar Dec 02 '25 05:12 alippai