
`interp` performance with chunked dimensions

Open slevang opened this issue 1 year ago • 9 comments

What is your issue?

I'm trying to perform 2D interpolation on a large 3D array that is heavily chunked along the interpolation dimensions and not the third dimension. The application could be extracting a timeseries from a reanalysis dataset chunked in space but not time, to compare to observed station data with more precise coordinates.

I use the advanced interpolation method as described in the documentation, with the interpolation coordinates specified by DataArrays with a shared dimension, like so:

%load_ext memory_profiler
import numpy as np
import dask.array as da
import xarray as xr

# Synthetic dataset chunked in the two interpolation dimensions
nt = 40000
nx = 200
ny = 200
ds = xr.Dataset(
    data_vars = {
        'foo':(
            ('t', 'x', 'y'), 
            da.random.random(size=(nt, nx, ny), chunks=(-1, 10, 10))),
    },
    coords = {
        't': np.linspace(0, 1, nt),
        'x': np.linspace(0, 1, nx),
        'y': np.linspace(0, 1, ny),
    }
)

# Interpolate to some random 2D locations
ni = 10
xx = xr.DataArray(np.random.random(ni), dims='z', name='x')
yy = xr.DataArray(np.random.random(ni), dims='z', name='y')
interpolated = ds.foo.interp(x=xx, y=yy)
%memit interpolated.compute()

With just 10 interpolation points, this example calculation uses about 1.5 * ds.nbytes of memory, and saturates around 2 * ds.nbytes by about 100 interpolation points.

This could definitely work better, as each interpolated point usually only requires a single chunk of the input dataset, and at most 4 if it is right on the corner of a chunk. For example we can instead do it in a loop and get very reasonable memory usage, but this isn't very scalable:

interpolated = []
for n in range(ni):
    interpolated.append(ds.foo.interp(x=xx.isel(z=n), y=yy.isel(z=n)))
interpolated = xr.concat(interpolated, dim='z')
%memit interpolated.compute()

I tried adding a .chunk({'z':1}) to the interpolation coordinates but this doesn't help. We can also do .sel(x=xx, y=yy, method='nearest') with very good performance.

Any tips to make this calculation work better with existing options, or otherwise ways we might improve the interp method to handle this case? Given the performance behavior I'm guessing we may be doing sequential interpolation for the dimensions, basically an interp1d call for all the xx points and from there another to the yy points, which for even a small number of points would require nearly all chunks to be loaded in. But I haven't explored the code enough yet to understand the details.

slevang avatar Jul 17 '22 14:07 slevang

Given the performance behavior I'm guessing we may be doing sequential interpolation for the dimensions, basically an interp1d call for all the xx points and from there another to the yy points, which for even a small number of points would require nearly all chunks to be loaded in.

Yeah I think this is right.

You could check if it was better before https://github.com/pydata/xarray/pull/4155 (if it worked that is)

cc @pums974 @Illviljan

dcherian avatar Jul 18 '22 16:07 dcherian

Interpolating on chunked dimensions doesn't work at all prior to #4155. The changes in #4069 are also relevant.

slevang avatar Jul 18 '22 16:07 slevang

You are right about the behavior of the code. I don't see any way to enhance that in the general case.

Maybe, in your case, rechunking before interpolating might be a good idea

pums974 avatar Jul 18 '22 16:07 pums974

The chunking structure on disk is pretty instrumental to my application, which requires fast retrievals of full slices in the time dimension. The loop option in my first post only takes about 10 seconds with ni=1000 which is fine for my use case, so I'll probably go with that for now. It would be interesting to dig deeper though and see if there is a way to handle this better in the interp logic.

slevang avatar Jul 18 '22 21:07 slevang

The current code also has the unfortunate side-effect of merging all chunks too.

I think we should instead think of generating a dask array of weights and then using xr.dot
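
As a rough illustration of the weights idea (a sketch only, not how interp is implemented): bilinear interpolation at each target point is a weighted sum over the grid, so it can be written as a weight array with dims (z, x, y) contracted against the data with xr.dot. The helper name linear_weights is hypothetical, the weights are dense NumPy arrays here rather than dask arrays, and the targets are assumed to lie inside the grid. The test data is a plane per time step, so the expected result is known exactly.

```python
import numpy as np
import xarray as xr


def linear_weights(coord, targets, dim, new_dim="z"):
    """Dense (new_dim, dim) weights for 1D linear interpolation.

    Hypothetical helper: assumes `coord` is strictly increasing and that
    every target lies inside [coord[0], coord[-1]].
    """
    grid = np.asarray(coord, dtype=float)
    pts = np.asarray(targets, dtype=float)
    # Index of the right-hand neighbour, clipped so both neighbours exist.
    hi = np.clip(np.searchsorted(grid, pts), 1, grid.size - 1)
    lo = hi - 1
    frac = (pts - grid[lo]) / (grid[hi] - grid[lo])
    w = np.zeros((pts.size, grid.size))
    rows = np.arange(pts.size)
    w[rows, lo] = 1.0 - frac
    w[rows, hi] = frac
    return xr.DataArray(w, dims=(new_dim, dim), coords={dim: grid})


rng = np.random.default_rng(0)
t = np.arange(5)
x = np.linspace(0, 1, 20)
y = np.linspace(0, 1, 30)
# A plane in (x, y) for each time step, so bilinear interpolation is exact.
data = t[:, None, None] + 2.0 * x[None, :, None] + 3.0 * y[None, None, :]
foo = xr.DataArray(data, dims=("t", "x", "y"), coords={"t": t, "x": x, "y": y})

xx = rng.random(4)
yy = rng.random(4)

# Combined 2D weights with dims (z, x, y); only 4 entries per z are non-zero.
weights = linear_weights(foo.x, xx, "x") * linear_weights(foo.y, yy, "y")

# With no dimension given, xr.dot sums over the dims the inputs share (x, y).
result = xr.dot(foo, weights)
```

The dense (z, x, y) weight array is wasteful for large grids. A real implementation would presumably need sparse or chunk-aligned weights so that each target point only touches the chunks it needs.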

dcherian avatar Jul 25 '22 16:07 dcherian

The current code also has the unfortunate side-effect of merging all chunks too

Don't really know what I'm talking about here, but it looks to me like the current dask-interpolation routine uses blockwise. That is, it's trying to simply map a function over each chunk in the array. To get the chunks into a structure where this is correct to do, you have to first merge all the chunks along the interpolation axis.

I would have expected interpolation to use map_overlap. You'd add some padding to each chunk, map the interpolation over each chunk (without combining them), then trim off the extra. By using overlap, you don't need to combine all the chunks into one big array first, so the operation can actually be parallel.
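
To make the pad / map / trim pattern concrete, here is a minimal sketch with dask's map_overlap. It is deliberately simpler than general interpolation: it only interpolates each point to the midpoint with its right-hand neighbour, a fixed stencil that needs one value from the adjacent chunk. The function name midpoint and the toy data are made up for illustration.

```python
import numpy as np
import dask.array as da


def midpoint(block):
    """Average each element with its right-hand neighbour inside the block.

    The last element of the padded block has no neighbour and is left
    unchanged; it falls in the overlap region that map_overlap trims off.
    """
    out = block.astype(float)  # astype returns a copy
    out[:-1] = 0.5 * (block[:-1] + block[1:])
    return out


values = np.arange(12.0) ** 2
chunked = da.from_array(values, chunks=4)

# Pad every chunk with one neighbouring element on each side, apply the
# function per chunk, then trim the padding again. Chunks are never merged.
result = chunked.map_overlap(midpoint, depth=1, boundary="nearest", dtype=float)
computed = result.compute()
```

The output keeps the input chunking (result.chunks equals chunked.chunks), which is the property that matters here. Arbitrary target locations change the output shape per chunk, so they would need more machinery than this.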

FYI, fixing this would probably be a big deal to geospatial people—then you could do array reprojection without GDAL! Unfortunately not something I have time to work on right now, but perhaps someone else would be interested?

gjoseph92 avatar Nov 16 '22 17:11 gjoseph92

The challenge is you could be interping to an unordered set of locations.

So perhaps we can sort the input locations, do the interp with map_overlap, then argsort the result back to expected order.
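
The sort / interpolate / un-sort bookkeeping can be sketched in plain NumPy. np.interp stands in for the chunk-wise interpolation step here, and the grid and target values are arbitrary illustration data.

```python
import numpy as np

grid = np.linspace(0, 1, 11)
values = grid ** 2

# Unordered target locations.
targets = np.array([0.73, 0.05, 0.41, 0.99, 0.20])

# Sort the targets so neighbouring targets hit neighbouring parts of the grid.
order = np.argsort(targets)
sorted_result = np.interp(targets[order], grid, values)

# argsort of the permutation gives its inverse, restoring the original order.
inverse = np.argsort(order)
result = sorted_result[inverse]
```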

dcherian avatar Nov 16 '22 17:11 dcherian

Linking the dask issue: https://github.com/dask/dask/issues/6474

dcherian avatar Nov 16 '22 18:11 dcherian

Hi all, maybe this is of interest here: with @ameliefroessl, we might have managed to reduce the memory usage for the case of a regular grid (while interp() assumes a rectilinear one), see: https://github.com/GlacioHack/geoutils/pull/537 with code here: https://github.com/rhugonnet/geoutils/blob/add_delayed_raster_functions/geoutils/raster/delayed.py#L242.

See a comparison of RAM/execution time here: https://github.com/GlacioHack/xdem/discussions/501#discussioncomment-9171932. The RAM usage is also checked automatically in our tests and doesn't seem to exceed what we expect :slightly_smiling_face:

Using @GenevieveBuckley's very nice blogpost on ragged output (https://blog.dask.org/2021/07/02/ragged-output), we tested both map_overlap(drop_axis=) and delayed and found that the latter really performs better to minimize memory usage.

Unfortunately the implementation is not generic for Xarray: a regular (equally spaced) grid along the interpolated dimensions is only a special case. So I guess the question is: is this case common enough that it would be worth implementing directly in Xarray when the interpolated dimensions are detected to be regular? Otherwise, for raster data, it will soon become available through our Xarray accessor #8041 :wink:

rhugonnet avatar Apr 26 '24 21:04 rhugonnet