# `rolling(...).construct(...)` blows up chunk size
### What happened?

When using `rolling(...).construct(...)` in https://github.com/coiled/benchmarks/pull/1552, I noticed that my Dask workers died from running out of memory because the chunk sizes get blown up.
### What did you expect to happen?

Naively, I would expect `rolling(...).construct(...)` to keep chunk sizes roughly constant instead of blowing them up quadratically with the window size.
### Minimal Complete Verifiable Example

```python
import dask.array as da
import xarray as xr

# Construct a dataset with a chunk size of (400, 400, 1), i.e. 1.22 MiB per chunk
ds = xr.Dataset(
    dict(
        foo=(
            ["latitude", "longitude", "time"],
            da.random.random((400, 400, 400), chunks=(-1, -1, 1)),
        ),
    )
)

# The dataset now has chunks of size (400, 400, 100, 100), i.e. 11.92 GiB per chunk
ds = ds.rolling(time=100, center=True).construct("window")
```
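For reference, the blowup factor can be checked with a bit of arithmetic (no dask required); the figures below assume float64 data, matching the example above.

```python
# Each input chunk holds 400 * 400 * 1 float64 values.
input_chunk_bytes = 400 * 400 * 1 * 8
print(input_chunk_bytes / 2**20)  # ≈ 1.22 MiB

# After construct("window") with a window of 100, each output chunk is
# (400, 400, 100, 100): the time axis absorbs ~window neighbors and a
# new window axis of length 100 is materialized on top.
output_chunk_bytes = 400 * 400 * 100 * 100 * 8
print(output_chunk_bytes / 2**30)  # ≈ 11.92 GiB

# The growth is quadratic in the window size: 100 * 100 = 10000x.
print(output_chunk_bytes // input_chunk_bytes)
```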
### MVCE confirmation
- [X] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- [X] Complete example — the example is self-contained, including all data and the text of any traceback.
- [X] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- [X] New issue — a search of GitHub Issues suggests this is not a duplicate.
- [X] Recent environment — the issue occurs with the latest version of xarray and its dependencies.
### Relevant log output

No response

### Anything else we need to know?

No response
### Environment

```
xarray: 2024.7.0
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.14.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: 1.4.0
dask: 2024.9.0
distributed: 2024.9.0
matplotlib: None
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.1
cupy: None
pint: None
sparse: 0.15.4
flox: 0.9.9
numpy_groupies: 0.11.2
setuptools: 73.0.1
pip: 24.2
conda: 24.7.1
pytest: 8.3.3
mypy: None
IPython: 8.27.0
sphinx: None
```
This is using the `sliding_window_view` trick under the hood, which composes badly with anything that does a memory copy (like `weighted` in your example):
https://github.com/dask/dask/blob/d45ea380eb55feac74e8146e8ff7c6261e93b9d7/dask/array/overlap.py#L808
We actually use this approach for `.rolling(...).mean()` but are clever about handling memory copies under the hood (https://github.com/pydata/xarray/pull/4915).
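For context, here is a minimal NumPy sketch of the trick being discussed: `sliding_window_view` itself is a zero-copy view, and the blowup only materializes once something downstream forces a copy.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

a = np.arange(10.0)            # 10 float64 elements, 80 bytes
v = sliding_window_view(a, 5)  # shape (6, 5), still a view

# The view shares memory with the original array: no copy yet.
assert np.shares_memory(v, a)
# Both strides are 8 bytes: successive windows overlap in memory.
assert v.strides == (8, 8)

# Anything that forces a contiguous copy materializes the window
# dimension, e.g. an explicit copy (or, per the thread, operations
# like np.nanmean that copy internally).
c = v.copy()
assert not np.shares_memory(c, a)
assert c.nbytes == 6 * 5 * 8   # 240 bytes vs. the original 80
```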
I'm not sure what the right solution is here.
- Perhaps dask can automatically rechunk the dimensions that are being "slided over"? We'd want the new dimension `window` to be singly-chunked by default, I think.
- On the xarray side, a lot of the pain stems from automatically padding with `NaN`s in `rolling.construct`. This has downstream consequences (`np.nanmean` uses a memory copy, for example). But this is a more complex fix: https://github.com/pydata/xarray/pull/5603
PS: I chatted with @phofl about this at FOSS4G. He has some context.
Yeah, this is definitely on my todo list. @hendrikmakait and I chatted briefly about this today; there is definitely something we have to do.
I support the approach, but it'd be good to see the impact on `ds.rolling().mean()`, which also uses `construct` but is clever enough to avoid the memory blowup.
I also wonder if, instead of `rolling().construct().weighted().mean()`, there should just be something like `rolling().weighted().mean()` or `rolling().mean(weights=...)`. From what I understand, the quadratic explosion of the shape and the chunks is not inherent to this computation; we could also solve it akin to a `map_overlap` computation.
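To illustrate the point that the computation itself doesn't need a materialized window dimension: a weighted rolling mean is just a correlation with the weights, sketched here in plain NumPy (using `mode="valid"`, i.e. ignoring the edge/NaN-padding behavior that `rolling(..., center=True)` adds):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
w = np.array([1.0, 2.0, 1.0])  # window weights

# Equivalent to materializing the windows (the construct() route)...
windows = sliding_window_view(x, 3)
via_construct = (windows * w).sum(axis=-1) / w.sum()

# ...but computed without ever creating a window dimension.
via_correlate = np.correlate(x, w, mode="valid") / w.sum()

assert np.allclose(via_construct, via_correlate)  # both [2., 3., 4.]
```

In a chunked setting, this is essentially what a `map_overlap`-style implementation would do: exchange only `window`-sized halos between neighboring chunks instead of expanding every chunk.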
Yes, https://github.com/pydata/xarray/issues/3937, but we've struggled to move on that.
`construct` is a pretty useful escape hatch for custom workloads, so we should optimize for it behaving sanely.