pint-xarray

`where` operation on a quantified DataArray raises RecursionError

Status: Open · raybellwaves opened this issue 3 years ago · 2 comments

import xarray as xr
import cf_xarray.units
import pint_xarray
from pint_xarray import unit_registry as ureg

xr.set_options(display_expand_data=False)

ds = xr.tutorial.open_dataset("air_temperature")

data = ds.air
quantified = data.pint.quantify()

expected = ds.air.where(ds.air.max())

quantified.where(quantified.max())
---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-11-9572a4a0e70e> in <module>
----> 1 quantified.where(quantified.max())

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/common.py in where(self, cond, other, drop)
   1284             cond = cond.isel(**indexers)
   1285
-> 1286         return ops.where_method(self, cond, other)
   1287
   1288     def set_close(self, close: Optional[Callable[[], None]]) -> None:

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/ops.py in where_method(self, cond, other)
    174     # alignment for three arguments is complicated, so don't support it yet
    175     join = "inner" if other is dtypes.NA else "exact"
--> 176     return apply_ufunc(
    177         duck_array_ops.where_method,
    178         self,

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
   1172     # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
   1173     elif any(isinstance(a, DataArray) for a in args):
-> 1174         return apply_dataarray_vfunc(
   1175             variables_vfunc,
   1176             *args,

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
    291
    292     data_vars = [getattr(a, "variable", a) for a in args]
--> 293     result_var = func(*data_vars)
    294
    295     if signature.num_outputs > 1:

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
    740             )
    741
--> 742     result_data = func(*input_data)
    743
    744     if signature.num_outputs == 1:

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/duck_array_ops.py in where_method(data, cond, other)
    288     if other is dtypes.NA:
    289         other = dtypes.get_fill_value(data.dtype)
--> 290     return where(cond, data, other)
    291
    292

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/duck_array_ops.py in where(condition, x, y)
    282 def where(condition, x, y):
    283     """Three argument where() with better dtype promotion rules."""
--> 284     return _where(condition, *as_shared_dtype([x, y]))
    285
    286

~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/duck_array_ops.py in f(*args, **kwargs)
     54             else:
     55                 wrapped = getattr(eager_module, name)
---> 56             return wrapped(*args, **kwargs)
     57
     58     else:

<__array_function__ internals> in where(*args, **kwargs)

~/miniconda3/envs/main/lib/python3.9/site-packages/pint/quantity.py in __array_function__(self, func, types, args, kwargs)
   1656
   1657     def __array_function__(self, func, types, args, kwargs):
-> 1658         return numpy_wrap("function", func, args, kwargs, types)
   1659
   1660     _wrapped_numpy_methods = ["flatten", "astype", "item"]

~/miniconda3/envs/main/lib/python3.9/site-packages/pint/numpy_func.py in numpy_wrap(func_type, func, args, kwargs, types)
    919     if name not in handled or any(is_upcast_type(t) for t in types):
    920         return NotImplemented
--> 921     return handled[name](*args, **kwargs)

~/miniconda3/envs/main/lib/python3.9/site-packages/pint/numpy_func.py in _where(condition, *args)
    552 def _where(condition, *args):
    553     args, output_wrap = unwrap_and_wrap_consistent_units(*args)
--> 554     return output_wrap(np.where(condition, *args))
    555
    556

... last 4 frames repeated, from the frame below ...

<__array_function__ internals> in where(*args, **kwargs)

RecursionError: maximum recursion depth exceeded

— raybellwaves, Sep 07 '21 02:09

Thanks for the report, @raybellwaves. This might be a bug in pint's wrapper for `np.where`, or it might be that xarray does not unpack the xarray argument before calling the underlying data's `where`.

I'll investigate.

— keewis, Sep 07 '21 08:09

It took me quite a while, but I can now confirm that this is a bug in pint:

import pint
import numpy as np

q = pint.Quantity([2, 3, 6, 5], "kelvin")
np.where(q.max(), q)

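Regardless of whether this call makes sense, a RecursionError is not the appropriate error here. With plain NumPy, calling `np.where` with exactly two arguments raises a `ValueError`, because `x` and `y` must be given together.

The recursion comes from pint's `_where` wrapper, shown in the traceback above. It unwraps only the `x`/`y` arguments and passes `condition` through unchanged. When the condition is itself a Quantity, as `q.max()` is, `np.where(condition, *args)` dispatches straight back to `Quantity.__array_function__`. That call loops forever.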

I'll forward this to pint's issue tracker.

— keewis, Apr 15 '22 10:04

This has been fixed in pint; the fix is included in pint>=0.20.

— keewis, Apr 14 '23 10:04