`PyDataset` of exactly correct size breaks when used in conjunction with exactly correct `validation_steps`
See title. For instance, creating a `PyDataset` of size `batch_size * num_batches` and then running `model.fit(..., validation_data=dataset, batch_size=batch_size, validation_steps=num_batches)` breaks because no `val_logs` are created in `model.evaluate()` (the value is `None`, which raises an error in the validation loop of `model.fit()`, where it is assumed to be a dict).
Minimal Example for torch backend:
```python
import os

os.environ["KERAS_BACKEND"] = "torch"

import keras


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(x)


class MyDataset(keras.utils.PyDataset):
    def __init__(self, batches, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.batches = batches
        self.batch_size = batch_size
        self.x = keras.random.normal((self.batches * self.batch_size, 2))
        self.y = keras.random.normal((self.batches * self.batch_size, 10))

    def __len__(self):
        return self.batches

    def __getitem__(self, item):
        start = item * self.batch_size
        stop = start + self.batch_size
        item = slice(start, stop)
        return self.x[item], self.y[item]


def test_validation_steps():
    model = MyModel()
    batches = 4
    batch_size = 4
    train_data = MyDataset(batches=batches, batch_size=batch_size)
    validation_data = MyDataset(batches=batches, batch_size=batch_size)
    model.compile(optimizer="AdamW", loss="mse")

    # works fine
    model.fit(train_data, epochs=2, validation_data=validation_data, steps_per_epoch=batches)
    # works fine
    model.fit(train_data, epochs=2, validation_data=validation_data, steps_per_epoch=batches, validation_steps=batches - 1)
    # doesn't work
    model.fit(train_data, epochs=2, validation_data=validation_data, steps_per_epoch=batches, validation_steps=batches)
```
This workaround makes the code run, but incorrectly displays zero loss:
```python
class MyModel(keras.Model):
    ...

    def evaluate(self, *args, **kwargs):
        val_logs = super().evaluate(*args, **kwargs)
        # evaluate() returns None here; substitute an empty dict so the
        # validation loop in model.fit() does not crash.
        if val_logs is None:
            val_logs = {}
        return val_logs
```
When you use `steps_per_epoch` or `validation_steps`, you keep drawing from the same generator across epochs. At epoch 2 there is no data left, because you already consumed it all during epoch 1, hence the failure.
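A toy illustration of this exhaustion behavior, using a plain Python iterator rather than Keras internals:

```python
# One shared iterator with 4 items, drained at 4 "steps" per epoch.
it = iter(range(4))

epoch_1 = [next(it) for _ in range(4)]  # consumes all 4 items
epoch_2 = next(it, None)                # exhausted, so we get None

print(epoch_1)  # [0, 1, 2, 3]
print(epoch_2)  # None
```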
Two options for you would be:
- Don't pass `validation_steps` if you want to use a new generator at every epoch.
- Modify your generator so that it can return `steps * epochs` batches instead, and pass an exact `validation_steps` value (see the sketch after this list).
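A minimal sketch of that second option, reusing `MyModel` and `MyDataset` from the reproduction above; sizing both datasets at `batches * epochs` is one reading of the suggestion, not a documented Keras recipe:

```python
epochs = 2
batches = 4
batch_size = 4

# Allocate steps * epochs batches so the shared iterators are not exhausted
# after epoch 1; each epoch still advances exactly `batches` steps.
train_data = MyDataset(batches=batches * epochs, batch_size=batch_size)
validation_data = MyDataset(batches=batches * epochs, batch_size=batch_size)

model = MyModel()
model.compile(optimizer="AdamW", loss="mse")
model.fit(
    train_data,
    epochs=epochs,
    validation_data=validation_data,
    steps_per_epoch=batches,
    validation_steps=batches,
)
```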