tf-keras
Fitting with generators and sample weights crashes when batch size varies between steps with 3D output
Copy of https://github.com/tensorflow/tensorflow/issues/57158
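When training a model on a `tf.data` dataset (built with `tf.data.Dataset.from_tensor_slices`) that includes sample weights, an error is thrown if the batch size varies between steps. In the reproducer below, batching 200 samples by 16 yields 12 full batches plus a final batch of 8. Batching by 20 yields equal-sized batches and trains without error.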
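To see which training step receives the short batch, here is a small pure-Python sketch. `batch_schedule` is a hypothetical helper, not a TensorFlow API. It assumes the `tf.data` default `drop_remainder=False` for `.batch()`. It also assumes that, because the dataset is infinite (`.repeat()`) and `steps_per_epoch` is set, Keras keeps drawing from one iterator across epochs rather than restarting it.

```python
from itertools import cycle, islice


def batch_schedule(num_samples, batch_size, epochs, steps_per_epoch):
    """Return (epoch, step, batch_size) tuples, 1-indexed (hypothetical helper).

    Assumes drop_remainder=False, so each pass over the data ends with a
    batch of num_samples % batch_size elements when that is nonzero, and
    assumes Keras consumes one continuous iterator across epochs.
    """
    full, rem = divmod(num_samples, batch_size)
    one_pass = [batch_size] * full + ([rem] if rem else [])
    # .repeat() after .batch() replays the same sequence of batch sizes.
    sizes = islice(cycle(one_pass), epochs * steps_per_epoch)
    return [
        (i // steps_per_epoch + 1, i % steps_per_epoch + 1, size)
        for i, size in enumerate(sizes)
    ]


# The failing configuration from the report: 200 samples, batch(16).
for epoch, step, size in batch_schedule(200, 16, epochs=2, steps_per_epoch=10):
    if size != 16:
        print(f"epoch {epoch}, step {step}: batch of {size}")
```

Under these assumptions the only odd-sized batch is the 13th overall, so the first epoch completes and the crash would surface at step 3 of the second epoch. With `batch(20)` every batch has 20 samples.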
Note: this test previously worked as expected for us in TF < 2.6.
Code also available here: https://colab.research.google.com/gist/sushreebarsa/f05f31c10c6155387aedd5951091979b/57158.ipynb#scrollTo=gkT5Zc0M22pg
```python
import numpy as np
import tensorflow as tf

# 200 samples of 20x20x3 inputs; targets keep the same spatial shape,
# so each sample's output is 3D: (20, 20, 3).
x = np.random.random((200, 20, 20, 3)) * 10
y = x.dot(np.random.random((3, 3)))
x = x.astype(np.uint8)

model = tf.keras.Sequential()
model.add(
    tf.keras.layers.Conv2D(
        filters=3,
        kernel_size=3,
        activation="relu",
        input_shape=(20, 20, 3),
        padding="same",
    )
)
model.summary()
model.compile(optimizer="adam", loss="mse")

# Works: without sample weights, the final batch of 8 is handled fine
gen = tf.data.Dataset.from_tensor_slices((x, y)).batch(16).repeat()
model.fit(gen, epochs=2, steps_per_epoch=10)

# Works: 200 / 20 = 10, so every batch has exactly 20 samples
gen3 = tf.data.Dataset.from_tensor_slices((x, y, np.ones((200,)))).batch(20).repeat()
model.fit(gen3, epochs=2, steps_per_epoch=10)

# Doesn't work: 200 = 12 * 16 + 8, so every 13th batch has only 8 samples
gen2 = tf.data.Dataset.from_tensor_slices((x, y, np.ones((200,)))).batch(16).repeat()
model.fit(gen2, epochs=2, steps_per_epoch=10)
```
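A possible stopgap, assuming it is acceptable to drop the trailing 8 samples on each pass, is to batch with `drop_remainder=True` so that every batch has the same size, as in the working `batch(20)` case. This is an untested sketch that only inspects the batch sizes. It does not fix the underlying bug, and the name `gen2_fixed` is illustrative.

```python
import numpy as np
import tensorflow as tf

# Same data shapes as the reproducer, with per-sample weights of 1.
x = np.random.random((200, 20, 20, 3)) * 10
y = x.dot(np.random.random((3, 3)))
x = x.astype(np.uint8)
w = np.ones((200,))

# drop_remainder=True discards the trailing partial batch (8 samples here),
# so every batch the model sees has exactly 16 samples.
gen2_fixed = (
    tf.data.Dataset.from_tensor_slices((x, y, w))
    .batch(16, drop_remainder=True)
    .repeat()
)

# Leading dimension of the first 13 batches; without drop_remainder the
# 13th batch would contain only 8 samples.
sizes = [int(xb.shape[0]) for xb, _, _ in gen2_fixed.take(13)]
print(sizes)
```

If the sizes are uniform, `gen2_fixed` could be passed to `model.fit(gen2_fixed, epochs=2, steps_per_epoch=10)` in place of `gen2`. The trade-off is that the last 8 samples are never seen in any pass.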