Fix cyclegan example so that model can be saved

Open balvisio opened this issue 3 years ago • 3 comments
Currently, running the cyclegan.py example raises the following error when the ModelCheckpoint callback tries to save the model:

Traceback (most recent call last):
  File "cyclegan.py", line 618, in <module>
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/keras/saving/saving_utils.py", line 93, in raise_model_input_error
    raise ValueError(
ValueError: Model <__main__.CycleGan object at 0x7f960c063a30> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.
2022-08-03 17:53:25.489850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

This PR fixes the issue by modifying the call function of CycleGan and invoking it during train_step. I am not sure what the best signature for the call function is. Currently, call() takes a list of tensors and, depending on which element of the list is populated, calls either the discriminator or the generator. Any more idiomatic solution would be greatly appreciated.

I noticed the following warnings with this PR:

1067/1067 [==============================] - ETA: 0s - G_loss: 4.4838 - F_loss: 4.0567 - D_X_loss: 0.1828 - D_Y_loss: 0.1233
WARNING:absl:Found untraced functions such as cycle_gan_layer_call_fn, cycle_gan_layer_call_and_return_conditional
se functions will not be directly callable after loading.
2022-08-03 18:13:58.801990: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

and when loading the weights:

2022-08-03 18:16:58.836205: W tensorflow/core/util/tensor_slice_reader.cc:96] Could not open ./model_checkpoints/cyclegan_checkpoints.001: FAILED_PRECONDITION: model_checkpoints/cyclegan_checkpoints.001; Is a directory: perhaps your file is in a different file format and you need to use a different restore operator?
Weights loaded successfully

The predicted and generated images look reasonable though.

This PR is related to: https://github.com/keras-team/keras-io/pull/925

balvisio avatar Aug 03 '22 18:08 balvisio

@fchollet / Keras team: Would someone kindly assign a reviewer or provide feedback on this PR? Thanks a lot!

balvisio avatar Sep 20 '22 03:09 balvisio

Thanks for the PR.

The changes to call() look good, but the changes to the training loop make it less readable IMO, please avoid.

If you want to reload a saved model, please use model.save_weights()/load_weights() rather than save(). You do not need to store the model's architecture in serialized format since you have access to the model's code. Using save() will try to trace the entire model.
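A minimal sketch of the weights-only round trip, for illustration. The tiny `build_model()` function and the `demo.weights.h5` file name are hypothetical stand-ins, not part of cyclegan.py; the point is only that the architecture is rebuilt from code and just the weights go through the file.

```python
import os
import tempfile

import keras
import numpy as np


def build_model():
    # Hypothetical stand-in for the real model: the architecture lives in
    # code, so it never needs to be serialized alongside the weights.
    return keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2)])


original = build_model()

# Assumed file name; the ".weights.h5" suffix is used because recent Keras
# versions require it for save_weights().
weights_path = os.path.join(tempfile.mkdtemp(), "demo.weights.h5")
original.save_weights(weights_path)

# Rebuild the same architecture from code, then restore only the weights.
restored = build_model()
restored.load_weights(weights_path)

weights_match = all(
    np.array_equal(a, b)
    for a, b in zip(original.get_weights(), restored.get_weights())
)
```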

To get the ModelCheckpoint callback to use save_weights(), just pass to it the save_weights_only=True argument.
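As a rough sketch of that suggestion, the callback could be configured as below. The checkpoint path is an assumption modeled on the directory seen in the logs above, with a `.weights.h5` suffix added because recent Keras versions require it for weights-only checkpoints; adjust it to your setup.

```python
import keras

# Assumed path; "{epoch:03d}" is filled in by the callback at save time.
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}.weights.h5"

# With save_weights_only=True the callback calls model.save_weights()
# instead of model.save(), so the subclassed model is never traced.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
)

# The callback is then passed to fit() as usual, e.g.
#   model.fit(dataset, epochs=1, callbacks=[checkpoint_callback])
# where `model` and `dataset` are whatever the example script defines.
```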

fchollet avatar Oct 02 '22 21:10 fchollet

Thank you @fchollet ! The fix was way more simple than I thought :)

balvisio avatar Oct 10 '22 21:10 balvisio

@balvisio Thanks for the PR. It looks like this has been fixed, so I'll close the request. Thanks!

pcoet avatar Aug 16 '23 20:08 pcoet