keras-io
Fix cyclegan example so that model can be saved
Currently, when the cyclegan.py example is run, the following error occurs while the ModelCheckpoint callback is saving the model:
Traceback (most recent call last):
  File "cyclegan.py", line 618, in <module>
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/keras/saving/saving_utils.py", line 93, in raise_model_input_error
    raise ValueError(
ValueError: Model <__main__.CycleGan object at 0x7f960c063a30> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.
2022-08-03 17:53:25.489850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
This PR fixes the issue by modifying the call() method of CycleGan and calling the model during train_step().
I am not sure what the best signature for call() is. Currently, call() takes a list of tensors and, depending on which element of the list is populated, calls either the discriminator or the generator.
Any more idiomatic solution is greatly appreciated.
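A minimal sketch of what the error message asks for, assuming a heavily simplified, hypothetical `TinyCycleGan` in which each generator is a single `Dense` layer (not the architecture used in cyclegan.py). Once `call()` defines a forward pass, invoking the model through `Model.__call__`, i.e. `model(inputs)`, builds its weights. As one possible alternative to the list-based signature described in the PR, this sketch takes a `(real_x, real_y)` tuple and runs only the two generators.

```python
import numpy as np
import keras


class TinyCycleGan(keras.Model):
    """Hypothetical, heavily simplified stand-in for CycleGan.

    Each "generator" is a single Dense layer; the real example uses much
    larger convolutional generators and discriminators.
    """

    def __init__(self, features=4, **kwargs):
        super().__init__(**kwargs)
        self.gen_G = keras.layers.Dense(features, name="gen_G")  # X -> Y
        self.gen_F = keras.layers.Dense(features, name="gen_F")  # Y -> X

    def call(self, inputs):
        # Defining the forward pass lets Keras build the model from real data.
        real_x, real_y = inputs
        fake_y = self.gen_G(real_x)
        fake_x = self.gen_F(real_y)
        return fake_y, fake_x


model = TinyCycleGan()
x = np.random.rand(2, 4).astype("float32")
y = np.random.rand(2, 4).astype("float32")

# Going through Model.__call__ (not model.call) creates the weights.
fake_y, fake_x = model((x, y))
print(model.built, tuple(fake_y.shape), tuple(fake_x.shape))
```

Inside a custom train_step(), the same rule applies: run the forward pass as `self(inputs)` rather than `self.call(inputs)`.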
I noticed the following warnings with this PR:
1067/1067 [==============================] - ETA: 0s - G_loss: 4.4838 - F_loss: 4.0567 - D_X_loss: 0.1828 - D_Y_loss: 0.1233
WARNING:absl:Found untraced functions such as cycle_gan_layer_call_fn, cycle_gan_layer_call_and_return_conditional
se functions will not be directly callable after loading.
2022-08-03 18:13:58.801990: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
And when loading the weights:
2022-08-03 18:16:58.836205: W tensorflow/core/util/tensor_slice_reader.cc:96] Could not open ./model_checkpoints/cyclegan_checkpoints.001: FAILED_PRECONDITION: model_checkpoints/cyclegan_checkpoints.001; Is a directory: perhaps your file is in a different file format and you need to use a different restore operator?
Weights loaded successfully
The predicted and generated images look reasonable though.
This PR is related to: https://github.com/keras-team/keras-io/pull/925
@fchollet / Keras team: Would someone kindly assign a reviewer or provide feedback on this PR? Thanks a lot!
Thanks for the PR.
The changes to call() look good, but the changes to the training loop make it less readable IMO; please avoid them.
If you want to reload a saved model, please use model.save_weights()/load_weights() rather than save(). You do not need to store the model's architecture in serialized format since you have access to the model's code. Using save() will try to trace the entire model.
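A minimal sketch of the save_weights()/load_weights() round trip, again using a hypothetical stand-in model rather than the real CycleGan. Because the architecture lives in the Python class, a fresh instance only needs to be built (by calling it once on data) before it receives the stored weights. The `.weights.h5` filename suffix is an assumption chosen so the snippet also works with Keras 3, which requires it for `save_weights()`.

```python
import os
import tempfile

import numpy as np
import keras


class TinyCycleGan(keras.Model):
    # Hypothetical stand-in: two Dense "generators" instead of the real ones.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.gen_G = keras.layers.Dense(4)
        self.gen_F = keras.layers.Dense(4)

    def call(self, inputs):
        real_x, real_y = inputs
        return self.gen_G(real_x), self.gen_F(real_y)


x = np.random.rand(2, 4).astype("float32")
y = np.random.rand(2, 4).astype("float32")

original = TinyCycleGan()
original((x, y))  # run the forward pass once so the weights exist

# Only the weights are written; nothing about the architecture is traced.
path = os.path.join(tempfile.mkdtemp(), "cyclegan.weights.h5")
original.save_weights(path)

restored = TinyCycleGan()
restored((x, y))  # build with the same input shapes before loading
restored.load_weights(path)

# Both instances should now produce identical outputs.
same = all(
    np.allclose(np.asarray(a), np.asarray(b))
    for a, b in zip(original((x, y)), restored((x, y)))
)
print(same)
```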
To get the ModelCheckpoint callback to use save_weights(), just pass it the save_weights_only=True argument.
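For example, a sketch of wiring up the callback, assuming a throwaway Sequential model and random data in place of CycleGan and its datasets. The `{epoch:03d}` pattern follows the checkpoint names in the logs; the `.weights.h5` suffix is an assumption needed for Keras 3, which rejects other filenames when `save_weights_only=True`.

```python
import os
import tempfile

import numpy as np
import keras

# Hypothetical tiny model and data, only used to exercise the callback.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

ckpt_dir = tempfile.mkdtemp()
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    # {epoch:03d} is filled with the 1-based epoch number at save time.
    filepath=os.path.join(ckpt_dir, "cyclegan_checkpoints.{epoch:03d}.weights.h5"),
    save_weights_only=True,  # call save_weights() instead of save()
)
model.fit(x, y, epochs=2, verbose=0, callbacks=[checkpoint_cb])
print(sorted(os.listdir(ckpt_dir)))
```

Each epoch then produces a weights file that load_weights() can read back into a freshly built model.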
Thank you @fchollet! The fix was way simpler than I thought :)
@balvisio Thanks for the PR. It looks like this has been fixed, so I'll close the request. Thanks!