xarrayutils
xarrayutils copied to clipboard
Improved performance for `remove_bottom_values`
I recently recoded the `remove_bottom_values`
function using numba:
from numba import float64, guvectorize
import numpy as np
import xarray as xr
@guvectorize(
    [
        (float64[:], float64[:]),
    ],
    "(n)->(n)",
    nopython=True,
)
def _remove_last_value(data, output):
    """Copy `data` into `output`, masking the value that precedes each NaN.

    Generalized ufunc kernel with signature ``(n)->(n)`` operating on
    float64 vectors. Every element whose successor is NaN is set to NaN,
    and the final element is always set to NaN. For a profile that is
    valid down to some index and NaN below it, this removes exactly the
    last valid value.

    Parameters
    ----------
    data : 1D float64 array
        Input values along the core dimension.
    output : 1D float64 array
        Pre-allocated result buffer of the same length, filled in place.
    """
    n = len(data)
    # initialize output
    output[:] = data[:]
    # Guard against an empty core dimension: nopython mode performs no
    # bounds checking, so indexing the last element of a zero-length
    # array would touch memory outside the buffer.
    if n > 0:
        for i in range(n - 1):
            # output[i + 1] has not been modified yet at this point, so
            # this compares against the original neighbouring value.
            if np.isnan(output[i + 1]):
                output[i] = np.nan
        # take care of boundaries: the last element has no successor and
        # is always masked (a no-op if it is already NaN).
        output[n - 1] = np.nan
def remove_bottom_values_numba(da, dim='lev'):
    """Apply `_remove_last_value` along `dim` of `da` via `xr.apply_ufunc`.

    Parameters
    ----------
    da : xarray object
        Input passed straight to `xr.apply_ufunc`; its ``dtype`` is
        declared as the output dtype.
    dim : str, optional
        Name of the core dimension the kernel operates along.

    Returns
    -------
    The result of `xr.apply_ufunc`. Dask-backed inputs are handled with
    ``dask="parallelized"``.
    """
    ufunc_options = dict(
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        dask="parallelized",
        output_dtypes=[da.dtype],
    )
    return xr.apply_ufunc(_remove_last_value, da, **ufunc_options)
def remove_bottom_values_recoded(ds, dim="lev", fill_val=-1e10):
    """Remove the deepest values that are not nan along the dimension `dim`.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset whose data variables are masked along `dim`.
    dim : str, optional
        Name of the dimension to operate along. Its coordinate values are
        expected to increase along the dimension.
    fill_val : float, optional
        Currently unused; kept so existing callers that pass it keep working.

    Returns
    -------
    xr.Dataset
        Dataset with the masked data variables, the coordinates of `ds`
        and the attributes of `ds`.

    Raises
    ------
    ValueError
        If the first value of `dim` is larger than its last value.
    """
    # for now assume that values of `dim` increase along the dimension
    if ds[dim][0] > ds[dim][-1]:
        raise ValueError(
            f"It seems like `{dim}` has decreasing values. This is not supported yet. Please sort before."
        )

    # Forward `dim` explicitly; otherwise the helper silently falls back
    # to its own default dimension name.
    ds_masked = xr.Dataset(
        {va: remove_bottom_values_numba(ds[va], dim=dim) for va in ds.data_vars}
    )

    # Restore the dimension order of the input. Membership is tested
    # against `.dims`: `in` on a Dataset checks variable names (missing
    # dimensions that have no coordinate) and `in` on a DataArray checks
    # its values, neither of which is what is wanted here.
    ds_masked = ds_masked.transpose(
        *[di for di in ds.dims if di in ds_masked.dims]
    )
    ds_masked = ds_masked.assign_coords(
        {
            co: ds[co].transpose(*[di for di in ds.dims if di in ds[co].dims])
            for co in ds.coords
        }
    )
    ds_masked.attrs = ds.attrs
    return ds_masked
I am planning on implementing this here at some point. It might also be nice to generalize this to optionally keep only the bottom values, and to handle an arbitrary number of values rather than just one.