uxarray
Support remapping from unstructured grids to rectilinear grids
Discussed in https://github.com/UXARRAY/uxarray/discussions/768
Originally posted by tomvothecoder on April 18, 2024:

Hi again! For my upcoming E3SM demo with UXarray and xCDAT, we want to regrid E3SM data to a rectilinear grid so that we can use xCDAT's spatial averager. Is there a way to do this with UXarray? There was an earlier discussion about regridding to a lat-lon grid here.
What we found
We found `uxarray.UxDataset.nearest_neighbor_remap`, but we aren't sure whether it supports rectilinear target grids.
Our Current Workarounds
- `ncremap` (from NCO)
- The `simple_regrid` function shown below, a nearest-neighbor binning regridder built on SciPy's `cKDTree`
def _to_unit_xyz(lat, lon):
    """Convert latitude/longitude in degrees to points on the unit sphere.

    Euclidean (chord) distance between the returned 3-D points increases
    monotonically with great-circle distance. A KD-tree built on them
    therefore returns true spherical nearest neighbours. It is also
    unaffected by the longitude seam at 0/360 and by whether longitudes use
    the 0..360 or -180..180 convention.

    Returns an array of shape ``(N, 3)``, where N is the number of points
    after flattening the inputs.
    """
    lat_r = np.deg2rad(np.asarray(lat, dtype=float)).ravel()
    lon_r = np.deg2rad(np.asarray(lon, dtype=float)).ravel()
    cos_lat = np.cos(lat_r)
    return np.column_stack(
        (cos_lat * np.cos(lon_r), cos_lat * np.sin(lon_r), np.sin(lat_r))
    )


def simple_regrid(
    ds: xr.Dataset,
    target_var: str,
    nlat: xr.DataArray,
    nlon: xr.DataArray,
    infill: bool = True,
) -> xr.Dataset:
    """
    Nearest-neighbor remap a 2D field from E3SM column format to a
    user-defined rectilinear grid.

    Each source column is assigned to the target cell whose centre is
    nearest on the sphere. A cell's value is the mean of the finite source
    values assigned to it. Cells that receive no finite value are NaN
    (before optional infilling).

    Parameters
    ----------
    ds : xr.Dataset
        Source dataset to remap. Needs ``lat`` and ``lon`` with one value
        per source column (assumed to be in degrees) and a ``time``
        dimension.
    target_var : str
        Name of the variable to remap. Each time slice must hold one value
        per source column.
    nlat : xr.DataArray
        Target latitude values (degrees).
    nlon : xr.DataArray
        Target longitude values (degrees). Either longitude convention works,
        because distances are measured on the unit sphere.
    infill : bool, optional
        If True (default), fill empty target cells along ``lon`` using
        nearest-neighbor interpolation with extrapolation. Latitude rows
        that are entirely empty stay NaN.

    Returns
    -------
    xr.Dataset
        Dataset containing ``target_var`` on ``(time, lat, lon)``, with the
        source variable's attrs preserved. It also holds ``time_bnds``
        copied from ``ds`` when present, and any missing X/Y/T bounds added
        through the ``bounds`` accessor. NOTE(review): that accessor is
        presumably xcdat's, so xcdat must be imported.

    Raises
    ------
    ValueError
        If a time slice of ``target_var`` does not have exactly one value per
        source ``lat``/``lon`` point.

    Notes
    -----
    This regridding tool is intended as a simple, non-conservative regridding
    tool and is not fit for most scientific applications, but may be useful
    for data quick-looks.
    """
    # Target grid, flattened in C order so the binned result can be
    # reshaped straight back to (len(nlat), len(nlon)).
    lon2d, lat2d = np.meshgrid(nlon, nlat)
    grid_shape = lat2d.shape
    n_cells = lat2d.size

    # Source and target coordinates are identical for every time step, so
    # build the tree and find each source column's nearest cell only once.
    # Querying on the unit sphere (not raw lat/lon degrees) avoids wrong
    # matches across the 0/360 seam and near the poles.
    tree = spatial.cKDTree(_to_unit_xyz(lat2d, lon2d))
    _, ind = tree.query(_to_unit_xyz(ds.lat, ds.lon))

    field = ds[target_var]
    dset_out = []
    # Remap one time slice at a time.
    for i in range(field.sizes["time"]):
        values = np.asarray(field.isel(time=i), dtype=float).ravel()
        if values.size != ind.size:
            raise ValueError(
                f"{target_var!r} has {values.size} values per time step but "
                f"the source grid has {ind.size} lat/lon points"
            )

        # Only average finite values, so one missing source value does not
        # turn its whole target cell into NaN.
        valid = np.isfinite(values)
        total = np.bincount(ind[valid], weights=values[valid], minlength=n_cells)
        count = np.bincount(ind[valid], minlength=n_cells)

        # Cells with no contributing data give 0/0 -> NaN; that is intended.
        with np.errstate(divide="ignore", invalid="ignore"):
            cell_mean = total / count

        dset_out.append(
            xr.DataArray(
                data=cell_mean.reshape(grid_shape),
                dims=["lat", "lon"],
                coords={"lat": nlat, "lon": nlon},
                name=target_var,
                attrs=field.attrs,
            )
        )

    # Stack the time slices back along the source time coordinate.
    ds_final = xr.concat(dset_out, dim=ds.time).to_dataset()

    # Carry over the source time bounds, then let the bounds accessor add
    # whatever X/Y/T bounds are still missing.
    if "time_bnds" in ds.data_vars:
        ds_final["time_bnds"] = ds.time_bnds
    ds_final = ds_final.bounds.add_missing_bounds(["X", "Y", "T"])

    if infill:
        ds_final[target_var] = ds_final[target_var].interpolate_na(
            dim="lon", method="nearest", fill_value="extrapolate"
        )
    return ds_final
# Example usage: remap E3SM's TREFHT onto a global 2.5-degree grid.
# NOTE(review): assumes `np`, `xc` (xcdat), `uxds` and `simple_regrid` are
# already in scope.
target_var = "TREFHT"

# Cell centres every 2.5 degrees: lat -88.75..88.75 (72 values) and
# lon 1.25..358.75 (144 values). Bounds are not generated here;
# simple_regrid adds any missing bounds itself.
nlat, _ = xc.create_axis(
    "lat",
    np.arange(-88.75, 90, 2.5),
    attrs={"axis": "Y", "units": "degrees_north"},
    generate_bounds=False,
)
nlon, _ = xc.create_axis(
    "lon",
    np.arange(1.25, 360, 2.5),
    attrs={"axis": "X", "units": "degrees_east"},
    generate_bounds=False,
)

# Call the simple regridder.
uxds_r = simple_regrid(uxds, target_var, nlat, nlon)