xbatcher
Dimension order should be set by input_dims
What is your issue?
In most cases, the batch generator will permute the dimension order to agree with the order specified in input_dims. Here is an example:
>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100, 200)
>>> ds = xr.Dataset(
... {
... "foo": (["time", "y", "x", "z"], np.random.rand(*shape)),
... "bar": (["time", "y", "x", "z"], np.random.randint(0, 10, shape)),
... },
... {
... "x": (["x"], np.arange(shape[-2])),
... "y": (["y"], np.arange(shape[-3])),
... },
... )
>>> print(ds)
<xarray.Dataset>
Dimensions: (time: 10, y: 50, x: 100, z: 200)
Coordinates:
* x (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
* y (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: time, z
Data variables:
foo (time, y, x, z) float64 0.6615 0.04028 0.8633 ... 0.4632 0.6561
bar (time, y, x, z) int64 8 0 9 4 8 9 2 6 7 7 5 ... 3 2 0 7 2 3 2 1 3 6
>>> print(ds['foo'].shape)
(10, 50, 100, 200)
>>> bg = xbatcher.BatchGenerator(ds, input_dims={'x': 10, 'y': 5})
>>> print(bg[0])
<xarray.Dataset>
Dimensions: (y: 5, x: 10, sample: 2000)
Coordinates:
* x (x) int64 0 1 2 3 4 5 6 7 8 9
* y (y) int64 0 1 2 3 4
* sample (sample) object MultiIndex
* time (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 9 9 9 9 9 9 9 9 9 9 9 9
* z (sample) int64 0 1 2 3 4 5 6 7 ... 192 193 194 195 196 197 198 199
Data variables:
foo (sample, x, y) float64 0.6615 0.8259 0.09629 ... 0.2105 0.09571
bar (sample, x, y) int64 8 4 0 6 0 4 4 0 5 4 5 ... 2 3 8 3 4 1 6 1 9 4
>>> print(bg[0]['foo'].shape)
(2000, 10, 5)
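The behavior in this example can be stated as a property that a test could check: the dimensions named in input_dims come last in each variable, in the same order as the input_dims keys. The helper below is a hypothetical sketch (input_dims_trailing is not part of xbatcher). It only inspects a tuple of dimension names, such as the value of bg[0]['foo'].dims.

```python
from typing import Hashable, Mapping, Sequence


def input_dims_trailing(
    dims: Sequence[Hashable],
    input_dims: Mapping[Hashable, int],
    suffix: str = "",
) -> bool:
    """Hypothetical check: are the input dimensions the trailing dimensions
    of ``dims``, in the order given by the keys of ``input_dims``?

    ``suffix`` covers concat_input_dims=True, where the batch dimensions
    are renamed (for example "x" becomes "x_input").
    """
    expected = [f"{name}{suffix}" for name in input_dims]
    if not expected:
        # No input dimensions means there is nothing to check.
        return True
    return list(dims[-len(expected):]) == expected


# Dimension names from the batch printed above.
print(input_dims_trailing(("sample", "x", "y"), {"x": 10, "y": 5}))  # True
# The same check fails if x and y were swapped.
print(input_dims_trailing(("sample", "y", "x"), {"x": 10, "y": 5}))  # False
```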
In at least one case, the original dimension order is retained:
>>> import numpy as np
>>> import xarray as xr
>>> import xbatcher
>>> shape = (10, 50, 100)
>>> ds = xr.Dataset(
... {
... "foo": (["time", "y", "x"], np.random.rand(*shape)),
... "bar": (["time", "y", "x"], np.random.randint(0, 10, shape)),
... },
... {
... "x": (["x"], np.arange(shape[-1])),
... "y": (["y"], np.arange(shape[-2])),
... },
... )
# Original dimension order permuted to match input_dims
>>> bg = xbatcher.BatchGenerator(
... ds,
... input_dims={"x": 5, "y": 10},
... batch_dims={"time": 2},
... concat_input_dims=True,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions: (y_input: 10, x_input: 5, sample: 1000)
Coordinates:
x (sample, x_input) int64 0 1 2 3 4 0 1 ... 98 99 95 96 97 98 99
y (sample, y_input) int64 0 1 2 3 4 5 6 ... 43 44 45 46 47 48 49
* sample (sample) object MultiIndex
* input_batch (sample) int64 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
* time (sample) int64 0 1 2 3 4 5 6 7 8 9 0 ... 9 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: y_input, x_input
Data variables:
foo (sample, x_input, y_input) float64 0.3198 0.3109 ... 0.5785
bar (sample, x_input, y_input) int64 1 8 5 6 9 8 7 ... 6 0 9 4 8 5
>>> print(bg[0]['foo'].shape)
(1000, 5, 10)
# Original dimension order retained
>>> bg = xbatcher.BatchGenerator(
... ds,
... input_dims={"x": 5, "y": 10},
... batch_dims={"time": 2},
... concat_input_dims=False,
... )
>>> print(bg[0])
<xarray.Dataset>
Dimensions: (time: 10, y: 10, x: 5)
Coordinates:
* x (x) int64 0 1 2 3 4
* y (y) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: time
Data variables:
foo (time, y, x) float64 0.3198 0.5306 0.3465 ... 0.7873 0.5106 0.9177
bar (time, y, x) int64 1 0 2 6 5 8 0 1 2 0 5 ... 1 2 0 2 0 7 5 6 4 8 3
>>> print(bg[0]['foo'].shape)
(10, 10, 5)
We should document the intended behavior for ordering dimensions and test that batch shapes are consistent. I would have expected the original dimension order to be retained, in contrast to the most common behavior of the batch generator. @jhamman, can you provide insight into the originally intended behavior?
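A shape-based version of such a test could compare the trailing sizes of each batch with the sizes given in input_dims. The sketch below is hypothetical (trailing_shape_matches is not an xbatcher function) and uses only the shapes of bg[0]['foo'] printed in this issue.

```python
from typing import Hashable, Mapping, Tuple


def trailing_shape_matches(
    shape: Tuple[int, ...], input_dims: Mapping[Hashable, int]
) -> bool:
    """Hypothetical check: do the last len(input_dims) sizes of ``shape``
    equal the input_dims sizes, in input_dims order?"""
    expected = tuple(input_dims.values())
    if not expected:
        return True
    return tuple(shape[-len(expected):]) == expected


# Shapes of bg[0]['foo'] printed in the issue, with the input_dims used.
cases = [
    ((2000, 10, 5), {"x": 10, "y": 5}),  # first example
    ((1000, 5, 10), {"x": 5, "y": 10}),  # concat_input_dims=True
    ((10, 10, 5), {"x": 5, "y": 10}),    # concat_input_dims=False
]
for shape, input_dims in cases:
    print(shape, trailing_shape_matches(shape, input_dims))
```

Only the concat_input_dims=False case reports False. A size check alone cannot catch a swap when two input dimensions have the same length, so it complements a check on dimension names rather than replacing it.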
I will treat the edge case in which the output dimension order does not agree with the order specified by input_dims as a bug and submit a fix.
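One possible approach to such a fix, sketched under the assumption that each batch is still an xarray object before it is returned: compute a target order that keeps every other dimension at the front in its existing order and places the input dimensions last, in input_dims order, then pass that order to transpose (xarray's Dataset.transpose and DataArray.transpose accept dimension names). The helper name target_dim_order is hypothetical, and this is not xbatcher's actual implementation.

```python
from typing import Hashable, List, Mapping, Sequence


def target_dim_order(
    dims: Sequence[Hashable],
    input_dims: Mapping[Hashable, int],
    suffix: str = "",
) -> List[Hashable]:
    """Hypothetical helper: keep non-input dimensions first (in their
    current order), then the input dimensions in input_dims order.

    ``suffix`` handles renamed dimensions such as "x_input" when
    concat_input_dims=True.
    """
    wanted = [f"{name}{suffix}" for name in input_dims]
    missing = [name for name in wanted if name not in dims]
    if missing:
        raise ValueError(f"dimensions not found in batch: {missing}")
    others = [d for d in dims if d not in wanted]
    return others + wanted


# The concat_input_dims=False batch in the issue has dims ("time", "y", "x").
order = target_dim_order(("time", "y", "x"), {"x": 5, "y": 10})
print(order)  # ['time', 'x', 'y']
# With xarray, the batch could then be reordered with batch.transpose(*order).
```

Keeping the non-input dimensions in their existing order means only the input dimensions move, so batches that already follow input_dims (such as ("sample", "x", "y") in the first example) are left unchanged.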