tf-keras

Fitting with generators and sample weights crashes when batch size varies between steps with 3D output

Open: Zahlii opened this issue 2 years ago • 1 comment

Copy of https://github.com/tensorflow/tensorflow/issues/57158

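When a model is trained on a dataset built with tf.data.Dataset.from_tensor_slices that also yields sample weights, fit throws an error as soon as the batch size changes between steps. This happens, for example, when the last batch of a pass over the data is smaller than the others.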
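To make "the batch size varies" concrete, here is a small pure-Python sketch of the batch sizes each dataset in the report hands to fit. It assumes the usual Dataset.batch semantics (drop_remainder=False keeps a smaller final batch). It also assumes that Keras keeps a single iterator across epochs when an infinite, repeated dataset is combined with steps_per_epoch. The helper names batch_sizes and sizes_seen_by_fit are hypothetical and exist only for this illustration.

```python
from itertools import cycle, islice


def batch_sizes(num_samples, batch_size, drop_remainder=False):
    """Hypothetical helper: batch sizes produced by one pass of Dataset.batch."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_remainder:
        # Without drop_remainder, the leftover samples form a smaller last batch.
        sizes.append(remainder)
    return sizes


def sizes_seen_by_fit(num_samples, batch_size, epochs, steps_per_epoch):
    """Hypothetical helper: sizes consumed from `.batch(...).repeat()`.

    Assumes one iterator is reused across epochs (infinite dataset plus
    steps_per_epoch), so the passes over the data simply concatenate.
    """
    total_steps = epochs * steps_per_epoch
    return list(islice(cycle(batch_sizes(num_samples, batch_size)), total_steps))


steps16 = sizes_seen_by_fit(200, 16, epochs=2, steps_per_epoch=10)
steps20 = sizes_seen_by_fit(200, 20, epochs=2, steps_per_epoch=10)
print(steps16[12])   # 8: the 13th step receives the partial batch
print(set(steps20))  # {20}: batch size 20 never varies
```

With 200 samples and a batch size of 16, the 13th step gets only 8 samples. With a batch size of 20, every step gets exactly 20. This matches the "works" and "doesn't work" cases below.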

Note: this test worked as expected for us in TF < 2.6.

Code also available here: https://colab.research.google.com/gist/sushreebarsa/f05f31c10c6155387aedd5951091979b/57158.ipynb#scrollTo=gkT5Zc0M22pg

    import numpy as np
    import tensorflow as tf
    
    x = np.random.random((200, 20, 20, 3)) * 10
    y = x.dot(np.random.random((3, 3)))
    x = x.astype(np.uint8)
    
    model = tf.keras.Sequential()
    model.add(
        tf.keras.layers.Conv2D(
            filters=3,
            kernel_size=3,
            activation="relu",
            input_shape=(20, 20, 3),
            padding="same",
        )
    )
    model.summary()
    model.compile(optimizer="adam", loss="mse")
    
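    # Works: no sample weights (200 % 16 = 8, so batch sizes still vary: 16, ..., 16, 8)
    gen = tf.data.Dataset.from_tensor_slices((x, y)).batch(16).repeat()
    model.fit(gen, epochs=2, steps_per_epoch=10)
    
    # Works: sample weights, but 20 divides 200 evenly, so every batch has 20 samples
    gen3 = tf.data.Dataset.from_tensor_slices((x, y, np.ones((200,)))).batch(20).repeat()
    model.fit(gen3, epochs=2, steps_per_epoch=10)
    
    # Doesn't work: sample weights combined with varying batch sizes (16, ..., 16, 8)
    gen2 = tf.data.Dataset.from_tensor_slices((x, y, np.ones((200,)))).batch(16).repeat()
    model.fit(gen2, epochs=2, steps_per_epoch=10)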
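A possible workaround, sketched under an assumption: because batch size 20 (constant batches) works, forcing every batch to the same size may avoid the crash. Dataset.batch accepts drop_remainder=True, which discards the final 200 % 16 = 8 samples of each pass. This sketch has not been verified against this specific bug. It only shows that the resulting dataset always yields batches of 16. The name gen2_fixed is hypothetical.

```python
import numpy as np
import tensorflow as tf

# Same synthetic data as in the report.
x = np.random.random((200, 20, 20, 3)) * 10
y = x.dot(np.random.random((3, 3)))
x = x.astype(np.uint8)
w = np.ones((200,))  # per-sample weights, as in the report

# Assumed workaround: drop the smaller last batch of each pass so that
# every batch handed to fit() has exactly 16 samples.
gen2_fixed = (
    tf.data.Dataset.from_tensor_slices((x, y, w))
    .batch(16, drop_remainder=True)
    .repeat()
)

# With drop_remainder=True the batch dimension becomes static (16) in element_spec.
print(gen2_fixed.element_spec[0].shape)

# Then train exactly as in the report, e.g.:
# model.fit(gen2_fixed, epochs=2, steps_per_epoch=10)
```

The trade-off is that 8 of the 200 samples are skipped on every pass. If losing data matters, putting .repeat() before .batch(16) also keeps every batch at 16 samples, because an infinite stream never produces a remainder.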

Posted by Zahlii on Aug 18 '22