Support remapping between Rectilinear grids

Status: Open. philipc2 opened this issue 10 months ago (0 comments).

Discussed in https://github.com/UXARRAY/uxarray/discussions/768

Originally posted by tomvothecoder, April 18, 2024:

Hi again! For my upcoming E3SM demo with UXarray and xCDAT, we want to regrid E3SM data to a rectilinear grid so that we can use xCDAT's spatial averager. Is there a way to do this with UXarray? There was a discussion about regridding to lat-lon here.

What we found

We found uxarray.UxDataset.nearest_neighbor_remap, but we aren't sure whether it supports a rectilinear grid as the destination.

Our Current Workarounds

  1. ncremap (from NCO)
  2. The simple_regrid function shown below
```python
import numpy as np
import xarray as xr
import xcdat  # noqa: F401  (registers the .bounds accessor used below)
from scipy import spatial


def simple_regrid(
    ds: xr.Dataset,
    target_var: str,
    nlat: xr.DataArray,
    nlon: xr.DataArray,
    infill: bool = True,
) -> xr.Dataset:
    """
    Nearest-neighbor mapping of 2D fields from E3SM column format to a
    user-defined rectilinear grid.

    Parameters
    ----------
    ds : xr.Dataset
        Source dataset to remap.
    target_var : str
        Name of the variable to remap.
    nlat : xr.DataArray
        Target latitude values.
    nlon : xr.DataArray
        Target longitude values.
    infill : bool
        Flag to infill (with extrapolation) missing values (default True).

    Returns
    -------
    xr.Dataset

    Notes
    -----
    This is intended as a simple regridding tool and is not fit for most
    scientific applications, but it may be useful for data quick-looks.
    """
    # Target grid: a 2D (lat, lon) mesh, flattened into a list of points.
    LON, LAT = np.meshgrid(nlon, nlat)
    shp = LAT.shape

    # The source columns and the target grid do not change with time, so
    # build the nearest-neighbor tree and query it once, outside the loop.
    # Each source column is assigned to its nearest target grid point
    # (distance measured in raw degrees of lat/lon).
    tree = spatial.cKDTree(np.column_stack([LAT.ravel(), LON.ravel()]))
    _, ind = tree.query(np.column_stack([ds.lat.values, ds.lon.values]))
    n = tree.n

    # Total number of source columns that fall in each grid box.
    cnt = np.bincount(ind, minlength=n)

    dset_out = []

    # Loop over time steps and remap one map at a time.
    for i in range(len(ds.time)):
        data = ds[target_var].isel(time=i).values

        # Sum of data in each grid box.
        X = np.bincount(ind, weights=data, minlength=n)

        # Grid boxes with no data give 0/0, which becomes NaN.
        with np.errstate(divide="ignore", invalid="ignore"):
            grid = X / cnt

        # Reshape to the regular grid.
        grid = xr.DataArray(
            data=grid.reshape(shp),
            dims=["lat", "lon"],
            coords={"lat": nlat, "lon": nlon},
            name=target_var,
        )
        dset_out.append(grid)

    # Concatenate time steps and create an xr.Dataset.
    ds_final = xr.concat(dset_out, dim=ds.time).to_dataset()

    # Incorporate bounds from the original dataset.
    if "time_bnds" in ds.data_vars:
        ds_final["time_bnds"] = ds.time_bnds

    # Add missing bounds (xCDAT .bounds accessor).
    ds_final = ds_final.bounds.add_missing_bounds(["X", "Y", "T"])

    # Infill empty boxes along longitude (if desired).
    if infill:
        ds_final[target_var] = ds_final[target_var].interpolate_na(
            dim="lon", method="nearest", fill_value="extrapolate"
        )

    return ds_final
```
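The core of simple_regrid is a "nearest-target binning" step: every source column is assigned to its nearest target grid point with a KD-tree, and values landing in the same box are averaged with np.bincount. The sketch below isolates that step on a few synthetic points so it can be checked by hand; the helper name bin_to_nearest and the toy data are hypothetical and not part of UXarray or xCDAT.

```python
import numpy as np
from scipy import spatial


def bin_to_nearest(src_lat, src_lon, values, tgt_lat, tgt_lon):
    """Hypothetical helper: average source values into the nearest target point.

    Mirrors the core of simple_regrid: distances are measured in raw
    (lat, lon) degrees, and target boxes that receive no source point
    come back as NaN.
    """
    LON, LAT = np.meshgrid(tgt_lon, tgt_lat)
    tree = spatial.cKDTree(np.column_stack([LAT.ravel(), LON.ravel()]))
    # Index of the nearest target grid point for every source point.
    _, ind = tree.query(np.column_stack([src_lat, src_lon]))
    total = np.bincount(ind, weights=values, minlength=tree.n)
    count = np.bincount(ind, minlength=tree.n)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean = total / count  # 0/0 -> NaN for empty boxes
    return mean.reshape(LAT.shape)


# Toy example: 4 source points on a 2x2 target grid (lat 0/10, lon 0/10).
src_lat = np.array([0.1, -0.2, 0.3, 10.0])
src_lon = np.array([0.0, 0.4, 9.8, 0.2])
values = np.array([1.0, 3.0, 5.0, 7.0])
grid = bin_to_nearest(src_lat, src_lon, values,
                      tgt_lat=np.array([0.0, 10.0]),
                      tgt_lon=np.array([0.0, 10.0]))
# Row lat=0: [2.0, 5.0]; row lat=10: [7.0, nan]
print(grid)
```

The first two points share the (0, 0) box and are averaged to 2.0, while the (10, 10) box gets no points and stays NaN; those NaNs are what the optional infill step in simple_regrid fills afterwards.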
```python
import numpy as np
import xcdat as xc

# Define regrid targets: cell centers of a global 2.5-degree grid.
target_var = "TREFHT"
nlat, _ = xc.create_axis(
    "lat",
    np.arange(-88.75, 90, 2.5),
    attrs={"axis": "Y", "units": "degrees_north"},
    generate_bounds=False,
)
nlon, _ = xc.create_axis(
    "lon",
    np.arange(1.25, 360, 2.5),
    attrs={"axis": "X", "units": "degrees_east"},
    generate_bounds=False,
)

# Call the simple regridder on the source E3SM dataset (uxds).
uxds_r = simple_regrid(uxds, target_var, nlat, nlon)
```
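Because simple_regrid measures distance in raw (lat, lon) degrees, it relies on the source and target longitudes sharing one convention: a source stored on [-180, 180) (e.g. lon = -1.0) would be compared against the [0, 360) target above by raw difference, and raw degrees also ignore the convergence of meridians toward the poles. One common alternative, sketched here as an assumption rather than what UXarray or the original workaround does, is to build the KD-tree on 3D unit-sphere coordinates. The helper names lonlat_to_xyz and nearest_target_index are hypothetical.

```python
import numpy as np
from scipy import spatial


def lonlat_to_xyz(lat_deg, lon_deg):
    """Hypothetical helper: convert lat/lon in degrees to unit-sphere points."""
    lat = np.deg2rad(np.asarray(lat_deg, dtype=float))
    lon = np.deg2rad(np.asarray(lon_deg, dtype=float))
    return np.column_stack(
        [np.cos(lat) * np.cos(lon), np.cos(lat) * np.sin(lon), np.sin(lat)]
    )


def nearest_target_index(src_lat, src_lon, tgt_lat, tgt_lon):
    """Hypothetical helper: index into the flattened target mesh.

    Chordal (straight-line) distance on the unit sphere increases with
    great-circle distance, so the nearest point in 3D is also the nearest
    on the sphere, regardless of the longitude convention used.
    """
    LON, LAT = np.meshgrid(tgt_lon, tgt_lat)
    tree = spatial.cKDTree(lonlat_to_xyz(LAT.ravel(), LON.ravel()))
    _, ind = tree.query(lonlat_to_xyz(src_lat, src_lon))
    return ind


tgt_lat = np.array([0.0])
tgt_lon = np.array([1.25, 180.0, 358.75])
# A source point at lon = -1.0 (i.e. 359.0 on a 0-360 grid).
# In raw degrees it would match lon 1.25 (index 0); on the sphere it
# correctly matches lon 358.75 (index 2).
print(nearest_target_index([0.0], [-1.0], tgt_lat, tgt_lon))
```

Since the flattened mesh uses the same np.meshgrid(nlon, nlat) ordering as simple_regrid, indices produced this way could stand in for ind there without changing the bincount averaging.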

philipc2 · Apr 19 '24 16:04