xarray
Flox-based groupby operations don't support `dtype` in the `mean` method
Discussed in https://github.com/pydata/xarray/discussions/6901
Originally posted by tasansal on August 9, 2022:

We have been using the new groupby logic with flox and numpy_groupies; however, when we run the following, `dtype` is not recognized as a valid argument.
This breaks API compatibility for cases where you may not have the acceleration libraries installed.
Not sure if this has to be upstream in
In addition to base xarray, we have the following extras installed: flox, numpy_groupies, and Bottleneck.
We pass `dtype` because our data is float32, but we want the accumulator in `mean` to be float64 for accuracy. One workaround is to cast the variable to float64 before calling `mean`, but that may create a copy and cause a spike in memory usage.
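A NumPy-only sketch of the two options (not xarray or flox internals; the array shape below is an arbitrary assumption standing in for a float32 forecast field): casting first allocates a full float64 copy of the input, while passing `dtype` to the reduction asks NumPy for a float64 accumulator and result without the caller creating that copy.

```python
import numpy as np

# Hypothetical stand-in for a float32 field shaped (time, lat, lon).
rng = np.random.default_rng(0)
data = rng.random((365, 50, 50), dtype=np.float32)

# Option 1: cast first. This materializes a second array that is
# twice as large as the float32 original.
data_f64 = data.astype(np.float64)
mean_cast = data_f64.mean(axis=0)

# Option 2: keep the float32 input and request a float64 accumulator.
mean_dtype = data.mean(axis=0, dtype=np.float64)

print(data.nbytes, data_f64.nbytes)  # the float64 copy is 2x the size
print(mean_cast.dtype, mean_dtype.dtype)
```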
When Flox and numpy_groupies are not installed, it works as expected.
We are working with multi-dimensional time-series of weather forecast models.
```python
import xarray as xr

ds = xr.open_mfdataset(...)  # float32 weather-forecast variables
ds.groupby("time.month").mean(dtype="float64").compute()
```
Here is the end of the traceback; the error appears to come from flox:

```
  File "/home/altay_sansal_tgs_com/miniconda3/envs/wind-data-mos/lib/python3.10/site-packages/flox/core.py", line 786, in _aggregate
    return _finalize_results(results, agg, axis, expected_groups, fill_value, reindex)
  File "/home/altay_sansal_tgs_com/miniconda3/envs/wind-data-mos/lib/python3.10/site-packages/flox/core.py", line 747, in _finalize_results
    finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
TypeError: <lambda>() got an unexpected keyword argument 'dtype'
```
What is the best way to handle this? Should it be fixed in flox?
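The traceback shows flox forwarding `dtype` through `agg.finalize_kwargs` to a finalize lambda that does not accept it. Below is a minimal NumPy-only sketch of the shape a fix might take; it is not flox's code, and `mean_intermediates`, `finalize_mean`, and `groupby_mean` are hypothetical names. The grouped mean is split into per-group sum/count intermediates plus a finalize step, and `dtype` is threaded through both.

```python
import numpy as np


def mean_intermediates(values, codes, ngroups, acc_dtype):
    """Hypothetical: per-group sum (accumulated in acc_dtype) and count."""
    sums = np.zeros(ngroups, dtype=acc_dtype)
    # Scatter-add; float32 values are upcast when `sums` is float64.
    np.add.at(sums, codes, values)
    counts = np.bincount(codes, minlength=ngroups)
    return sums, counts


def finalize_mean(sums, counts, dtype):
    """Hypothetical finalize that accepts `dtype` instead of raising TypeError."""
    with np.errstate(invalid="ignore", divide="ignore"):
        result = sums / counts  # empty groups (count 0) become NaN
    return result.astype(dtype, copy=False)


def groupby_mean(values, codes, ngroups, dtype=None):
    """Hypothetical grouped mean; `dtype` sets both accumulator and output."""
    out_dtype = values.dtype if dtype is None else np.dtype(dtype)
    sums, counts = mean_intermediates(values, codes, ngroups, out_dtype)
    return finalize_mean(sums, counts, out_dtype)


values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
codes = np.array([0, 0, 1, 1])  # integer group labels, e.g. month - 1
print(groupby_mean(values, codes, 2).dtype)                   # float32
print(groupby_mean(values, codes, 2, dtype="float64").dtype)  # float64
```

Threading `dtype` into the intermediates matters as much as accepting it in finalize: casting only the final sum/count ratio would return a float64 result that was still accumulated in float32.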
Yeah, I think we need to fix this in flox.
Can you come up with a simple test case that checks that the accumulation is done properly?
It's not crashing for me, but the dtype is not the same when switching flox on/off:
```python
import numpy as np
import xarray as xr

ds = xr.tutorial.load_dataset("air_temperature")
assert ds.air.dtype == np.float32

for use_flox in (False, True):
    with xr.set_options(use_flox=use_flox):
        ds_mean = ds.groupby("time.month").mean(dtype="float64").compute()
        actual = ds_mean.air.dtype
        expected = np.float64
        print(f"{use_flox=}, {actual=}, {expected=}")
        assert actual == expected

# use_flox=False, actual=dtype('float64'), expected=<class 'numpy.float64'>
# use_flox=True, actual=dtype('float32'), expected=<class 'numpy.float64'>
```
Output of `xr.show_versions()`:

```
INSTALLED VERSIONS

commit: None
python: 3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:50:36) [MSC v.1929 64 bit (AMD64)]
python-bits: 64
OS: Windows
OS-release: 10
machine: AMD64
processor: Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
byteorder: little
LC_ALL: None
LANG: en
libhdf5: 1.12.1
libnetcdf: 4.8.1

xarray: 0.16.3.dev99+gc19467fb
pandas: 1.4.2
numpy: 1.22.4
scipy: 1.8.1
netCDF4: 1.5.8
pydap: installed
h5netcdf: 1.0.0
h5py: 3.6.0
Nio: None
zarr: 2.11.3
cftime: 1.6.0
nc_time_axis: 1.4.1
PseudoNetCDF: 3.2.2
rasterio: 1.2.10
cfgrib: None
iris: 3.2.1
bottleneck: 1.3.4
dask: 2022.6.0
distributed: 2022.6.0
matplotlib: 3.5.2
cartopy: 0.20.2
seaborn: 0.11.2
numbagg: 0.2.1
fsspec: 2022.7.1
cupy: None
pint: 0.19.2
sparse: 0.13.0
flox: 0.5.5
numpy_groupies: 0.9.17
setuptools: 62.5.0
pip: 22.1.2
conda: None
pytest: 7.1.2
IPython: 7.33.0
sphinx: 5.1.1
```
Added a synthetic test case for various configurations in xarray-contrib/flox#131
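As an illustration of what a synthetic accumulation check can target (a sketch under stated assumptions, not the test added in xarray-contrib/flox#131), float32 data can be chosen so that a float32 accumulator cannot produce the exact mean: 2**24 + 1 is not representable in float32, so adding 2**24 and 1 in float32 rounds back to 2**24 in either order. `grouped_mean` and `check_float64_accumulation` are hypothetical helpers; the check passes only if the sums were accumulated in float64.

```python
import numpy as np


def grouped_mean(values, codes, ngroups, acc_dtype):
    """Hypothetical grouped mean with an explicit accumulator dtype."""
    sums = np.zeros(ngroups, dtype=acc_dtype)
    np.add.at(sums, codes, values)  # scatter-add into the accumulator
    counts = np.bincount(codes, minlength=ngroups)
    return sums / counts  # the division is float64; only the sums differ


def check_float64_accumulation(mean_func):
    """Hypothetical check: fails unless float32 input is summed in float64."""
    # 2**24 + 1 rounds back to 2**24 in float32, in either summation order.
    values = np.array([2**24, 1, 2, 4], dtype=np.float32)
    codes = np.array([0, 0, 1, 1])
    expected = np.array([(2**24 + 1) / 2, 3.0])  # exact in float64
    actual = mean_func(values, codes, 2)
    np.testing.assert_array_equal(actual, expected)


# Passes: the accumulator is float64.
check_float64_accumulation(lambda v, c, n: grouped_mean(v, c, n, np.float64))

# Fails: a float32 accumulator loses the +1 in group 0.
try:
    check_float64_accumulation(lambda v, c, n: grouped_mean(v, c, n, np.float32))
except AssertionError:
    print("float32 accumulation detected")
```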