pint-xarray
Chunked pint arrays break on `rolling()`
Hi folks,

I noticed that when running `.rolling(...)` on a chunked `pint` array, an exception is raised that breaks the process:

```
TypeError: `pad_value` must be composed of integral typed values.
```

I outline three different cases below for running `.rolling()` on a `pint`-aware `DataArray`.
- Calculating the rolling sum on an in-memory `pint` array.
  - Works, but loses units as expected (e.g. https://github.com/xarray-contrib/pint-xarray/issues/6#issuecomment-611134048). Although it seems like running it with `xr.set_options(use_bottleneck=False)` preserves units (https://github.com/pydata/xarray/issues/7062#issuecomment-1254047656).
- Calculating the rolling sum on a chunked `pint` array, using `xarray` chunking.
  - This works, even without turning off `bottleneck`. However, this isn't an optimal solution for me, since one cannot query `ds.pint.units` on an `xarray`-chunked `pint` array. I like being able to do that for various QOL checks in a data pipeline.
- Calculating the rolling sum on a `pint` array chunked with `ds.pint.chunk(...)`.
  - This method preserves the units, but leads to the traceback seen above and in full detail below. It also breaks when turning off `bottleneck`.
```python
import pint_xarray
import xarray as xr

print(xr.__version__)
>>> '2022.6.0'
print(pint_xarray.__version__)
>>> '0.3'

data = xr.DataArray(range(3), dims='time').pint.quantify('kelvin')
print(data)
>>> <xarray.DataArray (time: 3)>
>>> <Quantity([0 1 2], 'kelvin')>

# Case 1: rolling sum with `pint` units.
# Loses the units as expected, but executes properly.
rs = data.rolling(time=2).sum()
print(rs)
>>> <xarray.DataArray (time: 3)>
>>> array([nan, 1., 3.])

# Case 2: rolling sum with `xr.chunk()`.
# Maintains the units after compute,
# but `data_xr_chunk.pint.units` returns `None` in the interim.
data_xr_chunk = data.chunk({'time': 1})
rs = data_xr_chunk.rolling(time=2).sum().compute()
print(rs)
>>> <xarray.DataArray (time: 3)>
>>> <Quantity([nan 1. 3.], 'kelvin')>

# Case 3: rolling sum with `ds.pint.chunk()`.
# Maintains units on the chunked array, but raises an exception
# (see full traceback below).
data_pint_chunk = data.pint.chunk({"time": 1})
rs = data_pint_chunk.rolling(time=2).sum().compute()
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 rs = data_pint_chunk.rolling(time=2).sum().compute()

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:155, in Rolling._reduce_method.<locals>.method(self, keep_attrs, **kwargs)
    151 def method(self, keep_attrs=None, **kwargs):
    153     keep_attrs = self._get_keep_attrs(keep_attrs)
--> 155     return self._numpy_or_bottleneck_reduce(
    156         array_agg_func,
    157         bottleneck_move_func,
    158         rolling_agg_func,
    159         keep_attrs=keep_attrs,
    160         fillna=fillna,
    161         **kwargs,
    162     )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:589, in DataArrayRolling._numpy_or_bottleneck_reduce(self, array_agg_func, bottleneck_move_func, rolling_agg_func, keep_attrs, fillna, **kwargs)
    586 kwargs.setdefault("skipna", False)
    587 kwargs.setdefault("fillna", fillna)
--> 589 return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:472, in DataArrayRolling.reduce(self, func, keep_attrs, **kwargs)
    470 else:
    471     obj = self.obj
--> 472 windows = self._construct(
    473     obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
    474 )
    476 result = windows.reduce(
    477     func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs
    478 )
    480 # Find valid windows based on count.

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/rolling.py:389, in DataArrayRolling._construct(self, obj, window_dim, stride, fill_value, keep_attrs, **window_dim_kwargs)
    384 window_dims = self._mapping_to_list(
    385     window_dim, allow_default=False, allow_allsame=False  # type: ignore[arg-type]  # https://github.com/python/mypy/issues/12506
    386 )
    387 strides = self._mapping_to_list(stride, default=1)
--> 389 window = obj.variable.rolling_window(
    390     self.dim, self.window, window_dims, self.center, fill_value=fill_value
    391 )
    393 attrs = obj.attrs if keep_attrs else {}
    395 result = DataArray(
    396     window,
    397     dims=obj.dims + tuple(window_dims),
   (...)
    400     name=obj.name,
    401 )

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:2314, in Variable.rolling_window(self, dim, window, window_dim, center, fill_value)
   2311 else:
   2312     pads[d] = (win - 1, 0)
-> 2314 padded = var.pad(pads, mode="constant", constant_values=fill_value)
   2315 axis = [self.get_axis_num(d) for d in dim]
   2316 new_dims = self.dims + tuple(window_dim)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/xarray/core/variable.py:1416, in Variable.pad(self, pad_width, mode, stat_length, constant_values, end_values, reflect_type, **pad_width_kwargs)
   1413 if reflect_type is not None:
   1414     pad_option_kwargs["reflect_type"] = reflect_type
-> 1416 array = np.pad(  # type: ignore[call-overload]
   1417     self.data.astype(dtype, copy=False),
   1418     pad_width_by_index,
   1419     mode=mode,
   1420     **pad_option_kwargs,
   1421 )
   1423 return type(self)(self.dims, array)

File <__array_function__ internals>:180, in pad(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/quantity.py:1730, in Quantity.__array_function__(self, func, types, args, kwargs)
   1729 def __array_function__(self, func, types, args, kwargs):
-> 1730     return numpy_wrap("function", func, args, kwargs, types)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:936, in numpy_wrap(func_type, func, args, kwargs, types)
    934 if name not in handled or any(is_upcast_type(t) for t in types):
    935     return NotImplemented
--> 936 return handled[name](*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/pint/numpy_func.py:660, in _pad(array, pad_width, mode, **kwargs)
    656 if key in kwargs:
    657     kwargs[key] = _recursive_convert(kwargs[key], units)
    659 return units._REGISTRY.Quantity(
--> 660     np.pad(array._magnitude, pad_width, mode=mode, **kwargs), units
    661 )

File <__array_function__ internals>:180, in pad(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/core.py:1762, in Array.__array_function__(self, func, types, args, kwargs)
   1759 if has_keyword(da_func, "like"):
   1760     kwargs["like"] = self
-> 1762 return da_func(*args, **kwargs)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:1229, in pad(array, pad_width, mode, **kwargs)
   1227 elif mode == "constant":
   1228     kwargs.setdefault("constant_values", 0)
-> 1229     return pad_edge(array, pad_width, mode, **kwargs)
   1230 elif mode == "linear_ramp":
   1231     kwargs.setdefault("end_values", 0)

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in pad_edge(array, pad_width, mode, **kwargs)
    957 def pad_edge(array, pad_width, mode, **kwargs):
    958     """
    959     Helper function for padding edges.
    960
    961     Handles the cases where the only the values on the edge are needed.
    962     """
--> 964     kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
    966     result = array
    967     for d in range(array.ndim):

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:964, in <dictcomp>(.0)
    957 def pad_edge(array, pad_width, mode, **kwargs):
    958     """
    959     Helper function for padding edges.
    960
    961     Handles the cases where the only the values on the edge are needed.
    962     """
--> 964     kwargs = {k: expand_pad_value(array, v) for k, v in kwargs.items()}
    966     result = array
    967     for d in range(array.ndim):

File ~/miniconda3/envs/analysis_py39/lib/python3.9/site-packages/dask/array/creation.py:910, in expand_pad_value(array, pad_value)
    908     pad_value = array.ndim * (tuple(pad_value[0]),)
    909 else:
--> 910     raise TypeError("`pad_value` must be composed of integral typed values.")
    912 return pad_value

TypeError: `pad_value` must be composed of integral typed values.
```
My solution in the interim is to do something like:

```python
units = data.pint.units
data = data.pint.dequantify()
rs = data.rolling(time=2).sum()
rs = rs.pint.quantify(units)
```
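The interim workaround above can be bundled into a small helper. This is a hypothetical sketch, not part of any API: `rolling_sum_with_units` is an illustrative name, and `da` is assumed to be a pint-quantified `DataArray` exposing the `.pint` accessor.

```python
# Hypothetical convenience wrapper for the dequantify/roll/requantify dance:
# strip the units, run the rolling reduction, then re-attach them.
def rolling_sum_with_units(da, **window_kwargs):
    units = da.pint.units                    # remember the original units
    dequantified = da.pint.dequantify()      # drop them so rolling() works
    rolled = dequantified.rolling(**window_kwargs).sum()
    return rolled.pint.quantify(units)       # re-attach the original units
```

Note this only round-trips the units, it does not transform them: a rolling sum of `mm/day` still comes back as `mm/day`.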
Another side note that came up here: I'm curious if there's any roadmap plan for recognizing integration of units for methods like `rolling().sum()`. E.g.,

```python
data = xr.DataArray(range(3), dims='time').pint.quantify('mm/day')
data.pint.units
>>> mm/day
data = data.rolling(time=2).sum()
data.pint.units
>>> mm
```
Thanks for the report, @riley-brady. It seems that `xarray` operations on `pint`+`dask` are not as thoroughly tested as `pint` and `dask` on their own. I think this is a bug in `pint` (or `dask`, not sure): we enable `force_ndarray_like` to convert scalars to 0d arrays, which means that the final call to `np.pad` becomes:

```python
np.pad(magnitude, pad_width, mode="constant", constant_values=np.array(0))
```

`numpy` seems to be fine with that, but `dask` is not.

@jrbourbeau, what do you think? Would it make sense to extend `expand_pad_value` to unpack 0d arrays (using `.item()`), or would you rather have the caller (`pint`, in this case) do that?
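To make the mismatch concrete, here is a minimal plain-`numpy` sketch: `np.pad` happily accepts a 0d array as the constant value, while `dask`'s `expand_pad_value` rejects anything that is not a plain scalar. The `.item()` unpacking at the end is the fix suggested above, shown here as a standalone snippet.

```python
import numpy as np

# numpy accepts a 0-d array as the pad value...
arr = np.arange(3)
padded = np.pad(arr, (1, 0), mode="constant", constant_values=np.array(0))
# → array([0, 0, 1, 2])

# ...but dask rejects it. Either pint or dask could collapse a 0-d array
# back to a plain Python scalar before dask sees it:
fill = np.array(0)
if isinstance(fill, np.ndarray) and fill.ndim == 0:
    fill = fill.item()
# fill is now a plain int, which dask's expand_pad_value accepts
```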
> I'm curious if there's any roadmap plan for recognizing integration of units for methods like `rolling().sum()`

I'm not sure I follow. Why would `rolling().sum()` work similar to integration, when all it does is compute a grouped sum? I'm not sure if this actually counts as integration, but you can multiply the result of the rolling sum with the `diff` of the time coordinate (which is a bit tricky because `time` is an indexed coordinate):

```python
data = xr.DataArray(
    range(3), dims="time", coords={"time2": ("time", [1, 2, 3])}
).pint.quantify("mm/day", time2="day")
dt = data.time2.pad(time=(1, 0)).diff(dim="time")
data.rolling(time=2).sum() * dt
```

and then you would have the correct units (with the same numerical result, because I chose the time coordinate to have increments of 1 `day`).
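The pad/diff combination above can be mirrored in plain `numpy` to see what it computes. This is just a sketch, assuming the numeric time coordinate `[1, 2, 3]` from the example (`xarray` pads with NaN by default, so the first increment is NaN, matching the NaN that the rolling sum already produces in its first window):

```python
import numpy as np

time2 = np.array([1.0, 2.0, 3.0])
padded = np.concatenate([[np.nan], time2])   # mirrors .pad(time=(1, 0))
dt = np.diff(padded)                         # [nan, 1., 1.]

rolling_sum = np.array([np.nan, 1.0, 3.0])   # rolling(time=2).sum() from above
integrated = rolling_sum * dt                # units would be mm/day * day = mm
```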
Thanks for the quick feedback on this issue, @keewis.

Also thanks for the demo with `.diff()`. You're right about the integration assumptions. In my specific use case I am doing a rolling sum of units `mm/day` with daily time steps, so in this case it should reflect total precip in `mm`, but that's not a fair assumption for many other cases. I'll give the `.diff()` method a try.
This should have been fixed in `dask` quite a while ago, but I'll leave this open until we have tests for it (probably after copying the test suite from `xarray`).