xbatcher
Need to prepend a size-1 batch dimension to arrays returned from batch[]
Otherwise, the user has to do this on their own to match the Keras interface.
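For concreteness, the manual step in question looks roughly like this (a minimal sketch; the dimension name 'sample' and the toy sizes are illustrative, not part of xbatcher's API):

import numpy as np
import xarray as xr
import xbatcher as xb

da = xr.DataArray(np.random.rand(100, 100), dims=['x', 'y'])
bgen = xb.BatchGenerator(da, {'x': 10, 'y': 10})
for batch in bgen:
    # batch has dims ('x', 'y') only; Keras wants a leading batch axis,
    # so today the user has to prepend a size-1 dimension by hand:
    keras_ready = batch.expand_dims('sample', 0)
    print(keras_ready.shape)  # (1, 10, 10)
    break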
@cmdupuis3 - I could use some more context here. Can you share an example that starts from a small DataArray (or Dataset) and produces the expected/unexpected output?
I think you could actually run the code here with vanilla xbatcher to see the problem.
A really useful way to move this discussion forward would be via a minimal reproducible example - a small, self-contained piece of code that does not rely on any external datasets. If you follow that link, you'll find some great advice on how to come up with such an example.
Because several different people work on xbatcher, we rely on such examples to reason about proposed changes. If this feature is important to have, we need to demonstrate that with concrete arguments and examples.
I understand that, but I don't think the example I posted can be reduced much more than it is. I've stripped out as much as I can and updated it, but it's substantively the same.
With the changes to the example in #38, I think the reproducer here has diverged slightly. The version below should demonstrate the dimension mismatch with the current xbatcher code.
I think there's something strange about the way the target is passed to the Keras model. Without the extra 'var' dimension, it seems like Keras interprets it as a scalar.
import numpy as np
import xarray as xr
from IPython.display import clear_output
import tensorflow as tf
import xbatcher as xb

# Two 640x640 random fields: Z is the input stencil, t1 holds the target.
Z = xr.DataArray(np.random.rand(640, 640), dims=['nlon', 'nlat'], name='Z')
t1 = xr.DataArray(np.random.rand(640, 640), dims=['nlon', 'nlat'], name='t1')
ds = xr.Dataset({'Z': Z, 't1': t1})

def train(ds, conv_dims_2D=[20, 20], nfilters=80):
    nlons = conv_dims_2D[0]
    nlats = conv_dims_2D[1]

    # Every dimension is an input dimension, so each batch is a single
    # 20x20 patch with no leading batch dimension.
    bgen = xb.BatchGenerator(
        ds,
        {'nlon': nlons, 'nlat': nlats},
        {'nlon': int(nlons / 2), 'nlat': int(nlats / 2)}
    )

    # One Conv2D whose kernel covers the whole patch, flattened and fed
    # to a Dense layer that predicts a single value per patch.
    input_stencil_2D = tf.keras.Input(shape=tuple(conv_dims_2D) + (1,))
    conv_layer_2D = tf.keras.layers.Conv2D(nfilters, conv_dims_2D)(input_stencil_2D)
    reshape_layer_2D = tf.keras.layers.Reshape((nfilters,))(conv_layer_2D)
    output_layer = tf.keras.layers.Dense(1)(reshape_layer_2D)

    model = tf.keras.Model(inputs=[input_stencil_2D], outputs=output_layer)
    model.compile(loss='mae', optimizer='Adam', metrics=['mae', 'mse', 'accuracy'])

    for batch in bgen:
        # Append a trailing channel ('var') dimension; note that there is
        # still no leading batch dimension, which is what trips up Keras.
        batch_stencil_2D = batch['Z'].expand_dims('var', 2)
        # The target is the center point of the t1 patch.
        batch_target = batch['t1'].expand_dims('var', 2).isel(nlat=int(nlats / 2),
                                                              nlon=int(nlons / 2))
        model.fit([batch_stencil_2D],
                  batch_target,
                  batch_size=32, epochs=2, verbose=0)
        clear_output(wait=True)

    return model

train(ds)
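In the meantime, the mismatch can be worked around by also prepending a size-1 dimension by hand inside the loop; a minimal sketch, replacing the two assignments in the training loop above (the dimension name 'sample' is arbitrary, and this is exactly the manual step this issue proposes folding into batch[]):

batch_stencil_2D = batch['Z'].expand_dims('var', 2).expand_dims('sample', 0)
batch_target = (batch['t1'].expand_dims('var', 2)
                .isel(nlat=int(nlats / 2), nlon=int(nlons / 2))
                .expand_dims('sample', 0))
# Shapes become (1, 20, 20, 1) and (1, 1): a leading size-1 batch axis
# that matches Keras's (batch, height, width, channels) convention.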
It appears that this issue only arises when all of a dataset's dimensions are input dimensions. Otherwise, the batches keep a non-trivial batch dimension and things work as initially expected.
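To illustrate that second case, here is a sketch in which one dimension is left out of input_dims (the names are illustrative, and the exact batch layout may vary across xbatcher versions):

import numpy as np
import xarray as xr
import xbatcher as xb

da = xr.DataArray(np.random.rand(4, 640), dims=['time', 'nlon'], name='Z')
# Only 'nlon' is an input dimension, so 'time' should be carried through
# intact in each batch and can serve as the leading batch axis.
bgen = xb.BatchGenerator(da, {'nlon': 20})
batch = next(iter(bgen))
print(batch.dims, batch.shape)  # expected: ('time', 'nlon') (4, 20)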