xarray.DataArray.str.cat() doesn't work on chunked data
What happened?
I was trying to concatenate some DataArrays of strings, and it kept just returning the first DataArray without any changes.
What did you expect to happen?
I was expecting it to return the strings concatenated together, with the separator between them.
Minimal Complete Verifiable Example
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((2, 2)).astype(str),
    coords={'x': np.arange(2), 'y': np.arange(2)},
    dims=['x', 'y'])
dac = da.chunk()
print((da == dac).values.all())
print((da.str.cat(da, sep='--') == dac.str.cat(dac, sep='--')).values.all())
print((da.str.cat(da, sep='--') == dac.compute().str.cat(dac.compute(), sep='--')).values.all())
>>> True
>>> False
>>> True
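The symptom (the first array apparently coming back unchanged) is consistent with the result being written into a buffer that kept the input's fixed-width string dtype. A plain-numpy sketch of that hypothesis (the `'<U3'` array here stands in for the chunked data's inferred dtype; this is an illustration, not xarray's actual code path):

```python
import numpy as np

a = np.array([['0.0', '0.0'], ['0.0', '0.0']])   # fixed-width dtype '<U3'
out = np.empty_like(a)                           # result buffer keeps '<U3'
out[...] = np.char.add(np.char.add(a, '--'), a)  # '0.0--0.0' silently truncated
print(out[0, 0])  # '0.0' - looks as if the input came back unchanged
```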
MVCE confirmation
- [x] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- [x] Complete example — the example is self-contained, including all data and the text of any traceback.
- [x] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- [x] New issue — a search of GitHub Issues suggests this is not a duplicate.
Relevant log output
No response
Anything else we need to know?
No response
Environment
INSTALLED VERSIONS
commit: None python: 3.8.10 (default, Mar 15 2022, 12:22:08) [GCC 9.4.0] python-bits: 64 OS: Linux OS-release: 5.11.0-27-generic machine: x86_64 processor: x86_64 byteorder: little LC_ALL: None LANG: C.UTF-8 LOCALE: ('en_US', 'UTF-8') libhdf5: 1.12.2 libnetcdf: 4.9.0
xarray: 2022.3.0 pandas: 1.4.2 numpy: 1.22.4 scipy: 1.8.1 netCDF4: 1.6.0 pydap: None h5netcdf: 1.0.1 h5py: 3.6.0 Nio: None zarr: 2.11.3 cftime: 1.6.1 nc_time_axis: None PseudoNetCDF: None rasterio: 1.2.10 cfgrib: None iris: None bottleneck: None dask: 2022.7.0 distributed: None matplotlib: 3.5.2 cartopy: None seaborn: None numbagg: None fsspec: 2022.5.0 cupy: None pint: None sparse: None setuptools: 62.3.2 pip: 22.1.2 conda: None pytest: None IPython: 8.3.0 sphinx: None
Thanks for your report. I think the issue is that dask cannot correctly infer the dtype of the result, or at least not its length (maybe because it does not do value-based casting? Not sure).
As a workaround you could do an intermediate cast to an object:
dac.astype(object).str.cat(dac, sep='--').astype("U").compute()
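To see why the object cast helps, here is a minimal numpy-only sketch (using `+`, which falls back to Python string concatenation on object arrays, as a stand-in for `str.cat`):

```python
import numpy as np

a = np.array([['0.0', '0.0'], ['0.0', '0.0']])  # fixed-width dtype '<U3'
obj = a.astype(object)                          # Python str objects: no width limit
cat = (obj + '--' + obj).astype('U')            # width now inferred from the values
print(cat.dtype, cat[0, 0])  # <U8 0.0--0.0
```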
Thanks for the workaround @mathause!
Is there a benefit to your approach, rather than calling compute() on each DataArray? It seems like calling compute() twice is faster for the MVCE example (but maybe it won't scale that way).
But either way, it would be nice if the function threw a warning/error when handed dask arrays!
Yes, good point - just calling compute() may be the better solution.
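For completeness, the compute-first alternative applied to the toy data from the MVCE (a sketch; once the data is in memory, str.cat behaves as expected):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((2, 2)).astype(str),
    coords={'x': np.arange(2), 'y': np.arange(2)},
    dims=['x', 'y'])
dac = da.chunk()

# Materialize the chunked arrays first, then concatenate eagerly.
result = dac.compute().str.cat(dac.compute(), sep='--')
print((result == da.str.cat(da, sep='--')).values.all())  # True, per the MVCE
```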