
Flox-based groupby operations don't support `dtype` in the mean method

Status: Open. dcherian opened this issue 1 year ago (3 comments).

Discussed in https://github.com/pydata/xarray/discussions/6901

Originally posted by tasansal, August 9, 2022:

We have been using the new groupby logic with Flox and numpy_groupies. However, when we run the code below, `dtype` is not recognized as a valid argument.

This breaks API compatibility for cases where you may not have the acceleration libraries installed.

Not sure if this has to be upstream in

In addition to base Xarray, we have the following extras installed: Flox, numpy_groupies, and Bottleneck.

We pass `dtype` because our data is float32, but we want the accumulator in `mean` to be float64 for accuracy. One workaround is to cast the variable to float64 before taking the mean, but that may create a full float64 copy and spike memory usage. When Flox and numpy_groupies are not installed, `mean(dtype="float64")` works as expected.
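For plain NumPy arrays, the difference between the two approaches can be sketched as follows. This is a minimal illustration on synthetic data (the array shape and values are assumptions, not the reporter's dataset): `mean(dtype=np.float64)` accumulates in float64 while the input stays float32, whereas `astype` first allocates a float64 copy twice the size of the input.

```python
import numpy as np

# Synthetic float32 data standing in for the forecast variable (assumption).
rng = np.random.default_rng(0)
data = rng.random((365, 100), dtype=np.float32)

# Option 1: request a float64 accumulator; `data` itself stays float32.
mean_acc = data.mean(axis=0, dtype=np.float64)

# Option 2: cast first; this materialises a float64 copy of the whole array.
data64 = data.astype(np.float64)
mean_cast = data64.mean(axis=0)

print(mean_acc.dtype, mean_cast.dtype)  # float64 float64
print(data.nbytes, data64.nbytes)       # 146000 292000
print(np.allclose(mean_acc, mean_cast))
```

Both options return float64 means. Only the second pays for an extra full-size array, which is why the `dtype` argument matters for large datasets.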

We are working with multi-dimensional time-series of weather forecast models.

```python
import xarray as xr

da = xr.open_mfdataset(...)  # input paths elided in the original report
da.groupby("time.month").mean(dtype="float64").compute()
```

Here is the end of the traceback. The error appears to come from Flox:

```
  File "/home/altay_sansal_tgs_com/miniconda3/envs/wind-data-mos/lib/python3.10/site-packages/flox/core.py", line 786, in _aggregate
    return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
  File "/home/altay_sansal_tgs_com/miniconda3/envs/wind-data-mos/lib/python3.10/site-packages/flox/core.py", line 747, in _finalize_results
    finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
TypeError: <lambda>() got an unexpected keyword argument 'dtype'
```
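The last frame calls `agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)`, so any keyword that the finalize callable does not declare raises this `TypeError`. The sketch below models that failure with hypothetical stand-ins (`finalize_without_dtype` and `finalize_with_dtype` are illustrative names, not flox code). Casting only at finalize time would not give the float64 accumulator the reporter wants. A real fix would also need to apply `dtype` to the intermediate sums, which this sketch does not model.

```python
import numpy as np

# Hypothetical finalize callables; NOT flox's real definitions, only a model of
# the failure mode: a lambda whose signature has no `dtype` parameter.
finalize_without_dtype = lambda sum_, count: sum_ / count


def finalize_with_dtype(sum_, count, dtype=None):
    """Divide accumulated sums by counts, optionally casting the result to `dtype`."""
    result = sum_ / count
    return result if dtype is None else result.astype(dtype)


# float32 intermediates, as produced when the accumulation itself ran in float32.
sums = np.array([10.0, 20.0], dtype=np.float32)
counts = np.array([4.0, 5.0], dtype=np.float32)
finalize_kwargs = {"dtype": np.float64}  # forwarded like **agg.finalize_kwargs

try:
    finalize_without_dtype(sums, counts, **finalize_kwargs)
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)
print(finalize_with_dtype(sums, counts, **finalize_kwargs).dtype)  # float64
```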

What is the best way to handle this? Should it be fixed in Flox?

dcherian, Aug 09 '22 16:08

Yeah I think we need to fix this in flox.

Can you come up with a simple test case that checks that the accumulation is done properly?
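One way to check that the accumulation happens in the requested dtype, independently of flox, is to compare against a NumPy-only reference. The sketch below uses a hypothetical helper (`grouped_mean_reference` is not an xarray or flox function). It assumes integer group codes 0..k-1 along the first axis and synthetic data in place of the tutorial dataset.

```python
import numpy as np


def grouped_mean_reference(values, labels, dtype=np.float64):
    """Per-group mean along axis 0 with sums accumulated in `dtype`.

    Hypothetical helper (not part of xarray or flox). `labels` must hold one
    integer group code in 0..k-1 for each row of `values`.
    """
    values = np.asarray(values)
    labels = np.asarray(labels)
    n_groups = int(labels.max()) + 1
    sums = np.zeros((n_groups,) + values.shape[1:], dtype=dtype)
    # Unbuffered scatter-add: rows sharing a label all accumulate into sums[label],
    # and each float32 row is upcast to `dtype` as it is added.
    np.add.at(sums, labels, values)
    counts = np.bincount(labels, minlength=n_groups).astype(dtype)
    # Reshape counts so they broadcast over the trailing (non-grouped) axes.
    return sums / counts.reshape((-1,) + (1,) * (values.ndim - 1))


# Synthetic stand-in for the tutorial data: 12 "months" of 30 rows each.
rng = np.random.default_rng(42)
months = np.repeat(np.arange(12), 30)
air = rng.random((360, 5), dtype=np.float32)

expected = grouped_mean_reference(air, months)
print(expected.dtype, expected.shape)  # float64 (12, 5)

# Cross-check one group against NumPy's own float64-accumulated mean.
print(np.allclose(expected[3], air[months == 3].mean(axis=0, dtype=np.float64)))
```

Comparing a groupby result against values like these tests the accumulation itself. Checking only the output dtype does not prove that the sums were accumulated in float64.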

dcherian, Aug 09 '22 16:08

It's not crashing for me, but the result dtype differs depending on whether flox is switched on or off:

```python
import numpy as np
import xarray as xr

ds = xr.tutorial.load_dataset("air_temperature")
assert ds.air.dtype == np.float32

for use_flox in (False, True):
    with xr.set_options(use_flox=use_flox):
        ds_mean = ds.groupby("time.month").mean(dtype="float64").compute()
        actual = ds_mean.air.dtype
        expected = np.float64
        print(f"{use_flox=}, {actual=}, {expected=}")
        # Fails on the flox path (second iteration), per the output below.
        assert actual == expected

# use_flox=False, actual=dtype('float64'), expected=<class 'numpy.float64'>
# use_flox=True, actual=dtype('float32'), expected=<class 'numpy.float64'>
```

```python
xr.show_versions()
```

```
INSTALLED VERSIONS

commit: None
python: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:50:36) [MSC v.1929 64 bit (AMD64)]
python-bits: 64
OS: Windows
OS-release: 10
machine: AMD64
processor: Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
byteorder: little
LC_ALL: None
LANG: en
libhdf5: 1.12.1
libnetcdf: 4.8.1

xarray: 0.16.3.dev99+gc19467fb
pandas: 1.4.2
numpy: 1.22.4
scipy: 1.8.1
netCDF4: 1.5.8
pydap: installed
h5netcdf: 1.0.0
h5py: 3.6.0
Nio: None
zarr: 2.11.3
cftime: 1.6.0
nc_time_axis: 1.4.1
PseudoNetCDF: 3.2.2
rasterio: 1.2.10
cfgrib: None
iris: 3.2.1
bottleneck: 1.3.4
dask: 2022.6.0
distributed: 2022.6.0
matplotlib: 3.5.2
cartopy: 0.20.2
seaborn: 0.11.2
numbagg: 0.2.1
fsspec: 2022.7.1
cupy: None
pint: 0.19.2
sparse: 0.13.0
flox: 0.5.5
numpy_groupies: 0.9.17
setuptools: 62.5.0
pip: 22.1.2
conda: None
pytest: 7.1.2
IPython: 7.33.0
sphinx: 5.1.1
```

Illviljan, Aug 09 '22 17:08

Added a synthetic test case for various configurations in xarray-contrib/flox#131

tasansal, Aug 10 '22 15:08