xarrayutils

Improved performance for `remove_bottom_values`

Status: open. jbusecke opened this issue 2 years ago • 0 comments

I recently rewrote the `remove_bottom_values` function using numba:

from numba import float32, float64, guvectorize
import numpy as np
import xarray as xr

@guvectorize(
    [
        # float32 is listed first: NumPy uses the first loop that the inputs
        # can be safely cast to, so float32 data stays float32 instead of
        # being upcast to float64.
        (float32[:], float32[:]),
        (float64[:], float64[:]),
    ],
    "(n)->(n)",
    nopython=True,
)
def _remove_last_value(data, output):
    """Copy ``data`` into ``output``, masking every value directly above a NaN.

    The last element along the core dimension is always set to NaN. It is
    either already NaN or it is the deepest valid value of a column without
    any NaNs.

    NOTE: every value immediately followed by a NaN is masked. Interior NaN
    gaps, not only the trailing NaN "bottom", therefore also mask the value
    above them.
    """
    n = data.shape[0]
    output[:] = data[:]
    # An empty core dimension has nothing to mask, and ``output[-1]`` would
    # be an out-of-bounds write.
    if n == 0:
        return
    for i in range(n - 1):
        # Read from the unmodified input so the result does not depend on
        # the order in which ``output`` is overwritten.
        if np.isnan(data[i + 1]):
            output[i] = np.nan
    # Boundary: the last element is either NaN already or the deepest value.
    output[-1] = np.nan

def remove_bottom_values_numba(da, dim='lev'):
    """Mask the value directly above each NaN, and the last value, along ``dim``.

    Parameters
    ----------
    da : xarray.DataArray
        Input array; must contain ``dim``. If it is dask-backed, ``dim`` must
        consist of a single chunk, because ``dask="parallelized"`` does not
        allow a chunked core dimension.
    dim : str
        Dimension along which values are masked.

    Returns
    -------
    xarray.DataArray
        Masked copy of ``da`` with its attributes kept. ``dim`` is moved to
        the last position, which is how ``apply_ufunc`` orders core
        dimensions. Input that is not float32/float64 is converted to float64,
        since integer arrays cannot hold NaN.
    """
    # Without an explicit cast the gufunc would upcast non-float input
    # implicitly, and the declared ``output_dtypes`` would not match the
    # data that is actually computed.
    if da.dtype not in (np.float32, np.float64):
        da = da.astype(np.float64)

    return xr.apply_ufunc(
        _remove_last_value,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        dask="parallelized",
        output_dtypes=[da.dtype],
        keep_attrs=True,
    )

def remove_bottom_values_recoded(ds, dim="lev", fill_val=-1e10):
    """Remove the deepest values that are not nan along the dimension `dim`.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset.
    dim : str
        Dimension along which to mask. Its values are assumed to increase
        along the dimension.
    fill_val : float
        Currently unused; kept for backward compatibility.

    Returns
    -------
    xarray.Dataset
        Copy of ``ds`` in which every data variable containing ``dim`` is
        masked by ``remove_bottom_values_numba`` and keeps its original
        dimension order. Data variables without ``dim``, coordinates, and
        dataset attributes are carried over unchanged.

    Raises
    ------
    ValueError
        If the first value of ``dim`` is larger than the last one.
    """
    # for now assume that values of `dim` increase along the dimension
    if ds[dim][0] > ds[dim][-1]:
        raise ValueError(
            f"It seems like `{dim}` has decreasing values. This is not supported yet. Please sort before."
        )

    masked = {
        # Forward `dim` explicitly (the helper otherwise defaults to "lev")
        # and restore this variable's own dimension order, because
        # apply_ufunc moves the core dimension to the end.
        va: remove_bottom_values_numba(ds[va], dim=dim).transpose(*ds[va].dims)
        for va in ds.data_vars
        if dim in ds[va].dims
    }
    # `assign` returns a copy of `ds` with only these variables replaced,
    # so coordinates, attrs and all other variables are preserved.
    return ds.assign(masked)

I plan to implement this here at some point. It might also be nice to generalize it to optionally keep only the bottom value(s) instead, and to handle an arbitrary number of values rather than just one.

— jbusecke, Oct 19 '21 16:10