pint-xarray
`where` operation on a quantified DataArray gives `RecursionError`
# Reproducer: calling `where` on a pint-quantified DataArray raises RecursionError
# (see the traceback below).
import xarray as xr
import cf_xarray.units  # imported for its side effects only; the name is not used below
import pint_xarray  # imported so that the `.pint` accessor used below is available
from pint_xarray import unit_registry as ureg  # NOTE(review): `ureg` is unused in this snippet
xr.set_options(display_expand_data=False)
# Tutorial dataset; presumably downloaded on first use (may need network access) — TODO confirm.
ds = xr.tutorial.open_dataset("air_temperature")
data = ds.air
quantified = data.pint.quantify()  # pint-backed copy of `ds.air`
# Same call on the unit-less data, for comparison; `expected` is not used further here.
expected = ds.air.where(ds.air.max())
# The failing call: the condition `quantified.max()` is itself a quantity.
quantified.where(quantified.max())
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
<ipython-input-11-9572a4a0e70e> in <module>
----> 1 quantified.where(quantified.max())
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/common.py in where(self, cond, other, drop)
1284 cond = cond.isel(**indexers)
1285
-> 1286 return ops.where_method(self, cond, other)
1287
1288 def set_close(self, close: Optional[Callable[[], None]]) -> None:
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/ops.py in where_method(self, cond, other)
174 # alignment for three arguments is complicated, so don't support it yet
175 join = "inner" if other is dtypes.NA else "exact"
--> 176 return apply_ufunc(
177 duck_array_ops.where_method,
178 self,
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
1172 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
1173 elif any(isinstance(a, DataArray) for a in args):
-> 1174 return apply_dataarray_vfunc(
1175 variables_vfunc,
1176 *args,
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
291
292 data_vars = [getattr(a, "variable", a) for a in args]
--> 293 result_var = func(*data_vars)
294
295 if signature.num_outputs > 1:
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
740 )
741
--> 742 result_data = func(*input_data)
743
744 if signature.num_outputs == 1:
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/duck_array_ops.py in where_method(data, cond, other)
288 if other is dtypes.NA:
289 other = dtypes.get_fill_value(data.dtype)
--> 290 return where(cond, data, other)
291
292
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/duck_array_ops.py in where(condition, x, y)
282 def where(condition, x, y):
283 """Three argument where() with better dtype promotion rules."""
--> 284 return _where(condition, *as_shared_dtype([x, y]))
285
286
~/miniconda3/envs/main/lib/python3.9/site-packages/xarray/core/duck_array_ops.py in f(*args, **kwargs)
54 else:
55 wrapped = getattr(eager_module, name)
---> 56 return wrapped(*args, **kwargs)
57
58 else:
<__array_function__ internals> in where(*args, **kwargs)
~/miniconda3/envs/main/lib/python3.9/site-packages/pint/quantity.py in __array_function__(self, func, types, args, kwargs)
1656
1657 def __array_function__(self, func, types, args, kwargs):
-> 1658 return numpy_wrap("function", func, args, kwargs, types)
1659
1660 _wrapped_numpy_methods = ["flatten", "astype", "item"]
~/miniconda3/envs/main/lib/python3.9/site-packages/pint/numpy_func.py in numpy_wrap(func_type, func, args, kwargs, types)
919 if name not in handled or any(is_upcast_type(t) for t in types):
920 return NotImplemented
--> 921 return handled[name](*args, **kwargs)
~/miniconda3/envs/main/lib/python3.9/site-packages/pint/numpy_func.py in _where(condition, *args)
552 def _where(condition, *args):
553 args, output_wrap = unwrap_and_wrap_consistent_units(*args)
--> 554 return output_wrap(np.where(condition, *args))
555
556
... last 4 frames repeated, from the frame below ...
<__array_function__ internals> in where(*args, **kwargs)
RecursionError: maximum recursion depth exceeded
Thanks for the report, @raybellwaves. This might actually be a bug in `pint`'s code wrapping `np.where`, or `xarray` does not unpack the `xarray` argument before calling the data's `where`. I'll investigate.
It took me quite a while, but I can now confirm that this is a bug in `pint`:
# Minimal pint-only reproducer (no xarray involved).
import pint
import numpy as np
q = pint.Quantity([2, 3, 6, 5], "kelvin")
# The condition `q.max()` is itself a Quantity. Per this thread, this raises RecursionError
# on pint versions before 0.20.
# NOTE(review): np.where is normally called with either 1 or 3 positional arguments; this
# call passes 2 — confirm the intended form.
np.where(q.max(), q)
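Regardless of whether or not this call actually makes sense, a `RecursionError` is not the appropriate error for it. Judging from the traceback, pint's `_where` strips units only from the `x`/`y` arguments and passes the condition to `np.where` still as a `Quantity`. NumPy therefore dispatches straight back to `Quantity.__array_function__`, which calls `_where` again, and so on.
I'll forward this to `pint`'s issue tracker.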
This has been fixed in `pint` and is included in `pint>=0.20`.