`rolling(...).construct(...)` blows up chunk size

Open hendrikmakait opened this issue 1 year ago • 6 comments

What happened?

When using `rolling(...).construct(...)` in https://github.com/coiled/benchmarks/pull/1552, I noticed that my Dask workers ran out of memory and died because the chunk sizes get blown up.

What did you expect to happen?

Naively, I would expect `rolling(...).construct(...)` to try to keep chunk sizes constant instead of blowing them up quadratically in the window size.

Minimal Complete Verifiable Example

import dask.array as da
import xarray as xr

# Construct dataset with chunk size of (400, 400, 1) or 1.22 MiB
ds = xr.Dataset(
    dict(
        foo=(
            ["latitude", "longitude", "time"],
            da.random.random((400, 400, 400), chunks=(-1, -1, 1)),
        ),
    )
)

# Dataset now has chunks of size (400, 400, 100, 100) or 11.92 GiB
ds = ds.rolling(time=100, center=True).construct("window")
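The sizes quoted in the comments above can be checked with plain arithmetic. This is a minimal sketch that assumes 8-byte float64 elements (the dtype `da.random.random` produces); `chunk_nbytes` is a hypothetical helper, not part of xarray or dask, and the third shape is only a what-if for a time chunk that stays at length 1.

```python
from math import prod

ITEMSIZE = 8  # bytes per float64 element (assumption: default dtype of da.random.random)


def chunk_nbytes(shape, itemsize=ITEMSIZE):
    """Hypothetical helper: bytes needed to materialize one chunk of this shape."""
    return prod(shape) * itemsize


before = chunk_nbytes((400, 400, 1))          # chunk shape before construct
after = chunk_nbytes((400, 400, 100, 100))    # chunk shape reported after construct
linear = chunk_nbytes((400, 400, 1, 100))     # what-if: time chunk stays at 1

print(f"before: {before / 2**20:.2f} MiB")
print(f"after:  {after / 2**30:.2f} GiB")
print(f"growth: {after // before}x (window size squared)")
print(f"linear what-if: {linear / 2**20:.2f} MiB")
```

The growth factor is the window size squared: one factor of 100 comes from the new "window" dimension and the other from the time chunk growing from 1 to 100.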

MVCE confirmation

  • [X] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • [X] Complete example — the example is self-contained, including all data and the text of any traceback.
  • [X] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • [X] New issue — a search of GitHub Issues suggests this is not a duplicate.
  • [X] Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

No response

Environment

INSTALLED VERSIONS
------------------
commit: None
python: 3.12.6 | packaged by conda-forge | (main, Sep 11 2024, 04:55:15) [Clang 17.0.6 ]
python-bits: 64
OS: Darwin
OS-release: 23.6.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
libhdf5: None
libnetcdf: None

xarray: 2024.7.0
pandas: 2.2.2
numpy: 1.26.4
scipy: 1.14.0
netCDF4: None
pydap: None
h5netcdf: None
h5py: None
zarr: 2.18.2
cftime: 1.6.4
nc_time_axis: None
iris: None
bottleneck: 1.4.0
dask: 2024.9.0
distributed: 2024.9.0
matplotlib: None
cartopy: None
seaborn: None
numbagg: None
fsspec: 2024.6.1
cupy: None
pint: None
sparse: 0.15.4
flox: 0.9.9
numpy_groupies: 0.11.2
setuptools: 73.0.1
pip: 24.2
conda: 24.7.1
pytest: 8.3.3
mypy: None
IPython: 8.27.0
sphinx: None

hendrikmakait avatar Sep 26 '24 11:09 hendrikmakait

Thanks for opening your first issue here at xarray! Be sure to follow the issue template! If you have an idea for a solution, we would really welcome a Pull Request with proposed changes. See the Contributing Guide for more. It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better. Thank you!

welcome[bot] avatar Sep 26 '24 11:09 welcome[bot]

This uses the sliding_window_view trick under the hood, which composes badly with anything that makes a memory copy (like weighted in your example).

https://github.com/dask/dask/blob/d45ea380eb55feac74e8146e8ff7c6261e93b9d7/dask/array/overlap.py#L808
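To illustrate why a memory copy hurts here, the following sketch uses NumPy's `sliding_window_view` on a small in-memory array. It stands in for the dask implementation linked above and is not the code xarray runs; the array length and window size are arbitrary choices for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.arange(1_000, dtype=np.float64)
window = 100  # arbitrary window size for illustration

# The view re-reads the same buffer with overlapping strides, so no data is duplicated yet.
view = sliding_window_view(x, window)
shares = np.shares_memory(x, view)

# Any operation that allocates a fresh result materializes every window separately.
copied = view * 1.0

print(view.shape, view.strides)
print("shares memory with x:", shares)
print("input bytes:", x.nbytes, "materialized bytes:", copied.nbytes)
```

The view itself is cheap, but the materialized result is roughly `window` times larger than the input, which is the blow-up that a copying operation triggers on every chunk.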

We actually use this approach for .rolling.mean but are clever about handling memory copies under the hood (https://github.com/pydata/xarray/pull/4915).

I'm not sure what the right solution here is.

  1. Perhaps dask can automatically rechunk the dimensions that are being slid over? We'd want the new dimension ("window") to be singly-chunked by default, I think.
  2. On the xarray side, a lot of the pain stems from automatically padding with NaNs in rolling.construct. This has downstream consequences (np.nanmean uses a memory copy for example). But this is a more complex fix: https://github.com/pydata/xarray/pull/5603
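The second point can be pictured with a small NumPy sketch. It only roughly emulates a centered rolling window of odd length by padding with NaN by hand; it is an assumption-laden illustration and not xarray's actual implementation.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.arange(6, dtype=np.float64)
window = 3  # odd window, so centered padding is symmetric

# Pad with NaN so that there is one window per input element (assumed emulation of center=True).
pad = window // 2
padded = np.pad(x, pad, constant_values=np.nan)
windows = sliding_window_view(padded, window)

plain = windows.mean(axis=-1)            # edge windows contain NaN, so the result is NaN there
nan_aware = np.nanmean(windows, axis=-1)  # NaN-aware reduction is needed to get edge values

print(windows.shape)
print(plain)
print(nan_aware)
```

Because the edge windows contain NaN, downstream code has to use NaN-aware reductions, which is where the extra memory copies mentioned above come in.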

PS: I chatted with @phofl about this at FOSS4G. He has some context.

dcherian avatar Sep 26 '24 15:09 dcherian

Yeah, this is definitely on my todo list. @hendrikmakait and I chatted briefly about this today; there is definitely something we have to do.

phofl avatar Sep 26 '24 16:09 phofl

I support the approach, but it'd be good to see the impact on ds.rolling().mean() which also uses construct but is clever about it to avoid the memory blowup.

dcherian avatar Sep 26 '24 16:09 dcherian

I also wonder whether, instead of rolling().construct().weighted().mean(), there should just be something like rolling().weighted().mean() or rolling().mean(weights=...). From what I understand, the quadratic explosion of the shape and the chunks is not inherent to this computation; we could also solve it with something akin to a map_overlap computation.
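As a rough illustration that the windowed array need not be materialized, the sketch below computes a weighted rolling mean on a 1-D NumPy array in two ways: through a sliding-window view, and through `np.correlate`. The weights are hypothetical, only full windows are computed (no NaN padding at the edges), and none of this is an existing xarray API.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
x = rng.random(1_000)
weights = np.linspace(1.0, 2.0, 100)  # hypothetical window weights

# Construct-style: build every window, then reduce. The product below
# allocates an array about weights.size times larger than x.
windowed = sliding_window_view(x, weights.size)
via_construct = (windowed * weights).sum(axis=-1) / weights.sum()

# Direct: correlate the signal with the weights. The output has one value
# per full window and no windowed intermediate is allocated.
via_correlate = np.correlate(x, weights, mode="valid") / weights.sum()

print(via_construct.shape, via_correlate.shape)
print(np.allclose(via_construct, via_correlate))
```

Each output value depends only on a fixed-size neighborhood of the input, which is why an overlap-based computation along the rolling dimension could produce the same result chunk by chunk.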

hendrikmakait avatar Sep 27 '24 15:09 hendrikmakait

Yes, https://github.com/pydata/xarray/issues/3937, but we've struggled to move on that.

construct is a pretty useful escape hatch for custom workloads, so we should optimize for it behaving sanely.

dcherian avatar Sep 27 '24 15:09 dcherian