
Dimension order should be set by input_dims

Open maxrjones opened this issue 3 years ago • 2 comments

What is your issue?

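In most cases, the batch generator permutes the dimensions of each batch so that the input dimensions appear in the order given in input_dims. In the example below, the variables have dimensions (time, y, x, z), but each batch has dimensions (sample, x, y):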

>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher

>>> shape = (10, 50, 100, 200)
>>> ds = xr.Dataset(
...     {
...         "foo": (["time", "y", "x", "z"], np.random.rand(*shape)),
...         "bar": (["time", "y", "x", "z"], np.random.randint(0, 10, shape)),
...     },
...     {
...         "x": (["x"], np.arange(shape[-2])),
...         "y": (["y"], np.arange(shape[-3])),
...     },
... )
>>> print(ds)
<xarray.Dataset>
Dimensions:  (time: 10, y: 50, x: 100, z: 200)
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: time, z
Data variables:
    foo      (time, y, x, z) float64 0.6615 0.04028 0.8633 ... 0.4632 0.6561
    bar      (time, y, x, z) int64 8 0 9 4 8 9 2 6 7 7 5 ... 3 2 0 7 2 3 2 1 3 6
>>> print(ds['foo'].shape)
(10, 50, 100, 200)
>>> bg = xbatcher.BatchGenerator(ds, input_dims={'x': 10, 'y': 5})
>>> print(bg[0])
<xarray.Dataset>
Dimensions:  (y: 5, x: 10, sample: 2000)
Coordinates:
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9
  * y        (y) int64 0 1 2 3 4
  * sample   (sample) object MultiIndex
  * time     (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 9 9 9 9 9 9 9 9 9 9 9 9
  * z        (sample) int64 0 1 2 3 4 5 6 7 ... 192 193 194 195 196 197 198 199
Data variables:
    foo      (sample, x, y) float64 0.6615 0.8259 0.09629 ... 0.2105 0.09571
    bar      (sample, x, y) int64 8 4 0 6 0 4 4 0 5 4 5 ... 2 3 8 3 4 1 6 1 9 4
>>> print(bg[0]['foo'].shape)
(2000, 10, 5)

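In at least one case, the original dimension order is retained. Below, with batch_dims set, the order follows input_dims when concat_input_dims=True but not when concat_input_dims=False: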

>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100)
>>> ds = xr.Dataset(
...     {
...         "foo": (["time", "y", "x"], np.random.rand(*shape)),
...         "bar": (["time", "y", "x"], np.random.randint(0, 10, shape)),
...     },
...     {
...         "x": (["x"], np.arange(shape[-1])),
...         "y": (["y"], np.arange(shape[-2])),
...     },
... )
# Original dimensions permuted
>>> bg = xbatcher.BatchGenerator(
...     ds,
...     input_dims={"x": 5, "y": 10},
...     batch_dims={"time": 2},
...     concat_input_dims=True,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions:      (y_input: 10, x_input: 5, sample: 1000)
Coordinates:
    x            (sample, x_input) int64 0 1 2 3 4 0 1 ... 98 99 95 96 97 98 99
    y            (sample, y_input) int64 0 1 2 3 4 5 6 ... 43 44 45 46 47 48 49
  * sample       (sample) object MultiIndex
  * input_batch  (sample) int64 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * time         (sample) int64 0 1 2 3 4 5 6 7 8 9 0 ... 9 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: y_input, x_input
Data variables:
    foo          (sample, x_input, y_input) float64 0.3198 0.3109 ... 0.5785
    bar          (sample, x_input, y_input) int64 1 8 5 6 9 8 7 ... 6 0 9 4 8 5
>>> print(bg[0]['foo'].shape)
(1000, 5, 10)
# Original dimension order retained
>>> bg = xbatcher.BatchGenerator(
...     ds,
...     input_dims={"x": 5, "y": 10},
...     batch_dims={"time": 2},
...     concat_input_dims=False,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions:  (time: 10, y: 10, x: 5)
Coordinates:
  * x        (x) int64 0 1 2 3 4
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: time
Data variables:
    foo      (time, y, x) float64 0.3198 0.5306 0.3465 ... 0.7873 0.5106 0.9177
    bar      (time, y, x) int64 1 0 2 6 5 8 0 1 2 0 5 ... 1 2 0 2 0 7 5 6 4 8 3
>>> print(bg[0]['foo'].shape)
(10, 10, 5)

We should document the intended behavior for ordering dimensions and test that the output shape is consistent. I would have expected the original dimension order to be retained, which is not what the batch generator does in most cases. @jhamman, can you provide insight into the originally intended behavior?
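As a rough sketch of what such a test could check, the helper below computes the dimension order one might expect if input dimensions always came last, in the order given by input_dims. Both the helper name `expected_dim_order` and that convention are assumptions for illustration; neither is part of xbatcher, and the intended behavior is exactly what this issue asks about.

```python
def expected_dim_order(actual_dims, input_dims):
    """Return the dims with non-input dims first, then input dims.

    Hypothetical helper (not xbatcher API). Assumes the convention that
    non-input dims keep their existing order and input dims follow in the
    order given by ``input_dims``.
    """
    leading = [d for d in actual_dims if d not in input_dims]
    trailing = [d for d in input_dims if d in actual_dims]
    return tuple(leading + trailing)


# Dims reported in the first example: already consistent with input_dims.
stacked_actual = ("sample", "x", "y")
stacked_expected = expected_dim_order(stacked_actual, {"x": 10, "y": 5})

# Dims reported with concat_input_dims=False: original order retained.
retained_actual = ("time", "y", "x")
retained_expected = expected_dim_order(retained_actual, {"x": 5, "y": 10})

print(stacked_actual == stacked_expected)    # True
print(retained_actual == retained_expected)  # False
```

In a real test, `actual_dims` would come from something like `bg[0]["foo"].dims`, and the comparison would be an assertion rather than a print.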

maxrjones, Nov 17 '22 22:11

I will treat the edge case in which the output dimension order does not agree with the order specified by input_dims as a bug and submit a fix.
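Until a fix is available, one possible workaround is to transpose each batch so that the input dimensions follow the order in input_dims. This is a minimal sketch that relies only on xarray's `transpose` with an ellipsis; `reorder_to_input_dims` is a hypothetical helper, not an xbatcher function, and the dataset here is a stand-in for the (time, y, x) batch from the opening post rather than a real generator output.

```python
import numpy as np
import xarray as xr


def reorder_to_input_dims(batch, input_dims):
    """Move the input dims to the end, in the order given by ``input_dims``.

    Hypothetical helper for illustration; not part of xbatcher.
    """
    # The ellipsis keeps every other dim first, in its existing order.
    return batch.transpose(..., *input_dims)


# Stand-in for a batch whose dims came back as (time, y, x).
batch = xr.Dataset({"foo": (["time", "y", "x"], np.zeros((10, 10, 5)))})
reordered = reorder_to_input_dims(batch, {"x": 5, "y": 10})

print(reordered["foo"].dims)   # ('time', 'x', 'y')
print(reordered["foo"].shape)  # (10, 5, 10)
```

This only changes the order of the dimensions on the returned object; it does not address why the generator returns different orders in the first place.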

maxrjones, Dec 02 '22 22:12