tf-raft
Bug when trying to save full model
Hi,
I managed to train and validate the model on a small subset of the dataset, and I am now trying to save the full model, instead of only the weights, through the checkpoint callback.
According to https://keras.io/api/callbacks/model_checkpoint/, to do that we have to set save_weights_only to False, which forces the whole model to be saved.
However, when I change the value, I get the following error:
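For context, here is a minimal sketch of that kind of callback setup on a toy model. It is not the actual tf-raft training script: the model, the data, and the file path are placeholders chosen only to show the save_weights_only flag.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Placeholder model standing in for RAFT (assumption, for illustration only).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse")

# Hypothetical checkpoint location inside a temporary directory.
ckpt_path = os.path.join(tempfile.mkdtemp(), "full_model.keras")

# save_weights_only=False asks the callback to save the whole model,
# not just its weights.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path,
    save_weights_only=False,
)

# One epoch on dummy data so the callback fires once.
x = np.zeros((8, 4), dtype="float32")
y = np.zeros((8, 1), dtype="float32")
model.fit(x, y, epochs=1, verbose=0, callbacks=[checkpoint])
```

With a simple model like this the full save goes through; the failure only appears with the RAFT model.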
TypeError: in user code:
File "/Users/xxxxxx/miniforge3/lib/python3.9/site-packages/keras/saving/saving_utils.py", line 138, in _wrapped_model *
outputs = model(*args, **kwargs)
File "/Users/xxxxxx/miniforge3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler **
raise e.with_traceback(filtered_tb) from None
File "/var/folders/lh/57ty9td15gv9gd258s8d2sg00000gn/T/__autograph_generated_filelj5z_b00.py", line 16, in tf__call
correlation = ag__.converted_call(ag__.ld(CorrBlock), (ag__.ld(fmap1), ag__.ld(fmap2)), dict(num_levels=ag__.ld(self).corr_levels, radius=ag__.ld(self).corr_radius), fscope)
File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 106, in __init__
corr = self.correlation(fmap1, fmap2)
File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 156, in correlation
fmap1 = tf.reshape(fmap1, (batch_size, h*w, nch))
TypeError: Exception encountered when calling layer "raft" (type RAFT).
in user code:
File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/model.py", line 80, in call *
correlation = CorrBlock(fmap1, fmap2,
File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 106, in __init__ **
corr = self.correlation(fmap1, fmap2)
File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 156, in correlation
fmap1 = tf.reshape(fmap1, (batch_size, h*w, nch))
TypeError: Failed to convert elements of (None, 2852, 256) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
Call arguments received by layer "raft" (type RAFT):
• inputs=['tf.Tensor(shape=(None, 368, 496, 3), dtype=float32)', 'tf.Tensor(shape=(None, 368, 496, 3), dtype=float32)']
• training=False
2022-11-20 21:24:32.326266: W tensorflow/core/kernels/data/generator_dataset_op.cc:108] Error occurred when finalizing GeneratorDataset iterator: FAILED_PRECONDITION: Python interpreter state is not initialized. The process may be terminated.
[[{{node PyFunc}}]]
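It seems that inside the CorrBlock class, and in particular in the correlation member function, there is a conflict with the element data types or something along those lines: the target shape passed to tf.reshape, (None, 2852, 256), contains None for the batch dimension.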
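A minimal sketch that appears to reproduce the same kind of failure outside the repository. It assumes that correlation reads batch_size from the static shape of fmap1, which I have not confirmed against corr.py; the function names below are hypothetical. The feature-map size 46 x 62 is chosen so that h*w = 2852, as in the error message.

```python
import tensorflow as tf

# Feature map with an unknown batch size, as seen when the model is traced for saving.
SPEC = tf.TensorSpec(shape=(None, 46, 62, 256), dtype=tf.float32)


def reshape_static(fmap):
    # Assumed pattern: every size comes from the static shape, so
    # batch_size is None whenever the batch dimension is unknown.
    batch_size, h, w, nch = fmap.shape
    return tf.reshape(fmap, (batch_size, h * w, nch))


def reshape_dynamic(fmap):
    # Possible workaround: let tf.reshape infer the batch dimension with -1.
    _, h, w, nch = fmap.shape
    return tf.reshape(fmap, (-1, h * w, nch))


def traces_ok(fn):
    # Trace fn with an unknown batch size and report whether tracing succeeds.
    try:
        tf.function(fn, input_signature=[SPEC]).get_concrete_function()
        return True
    except (TypeError, ValueError):
        return False


static_ok = traces_ok(reshape_static)
dynamic_ok = traces_ok(reshape_dynamic)

# In eager mode with a concrete batch size the reshape works as expected.
out = reshape_dynamic(tf.zeros((2, 46, 62, 256)))
```

If this is indeed the cause, it would also explain why training works: the batch size is known at that point, while saving the full model traces the call with a batch dimension of None.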
Are there any suggestions for this issue? Essentially, my final goal is to export a full checkpoint/protobuf or Keras model so that I can convert it to TFLite.
Thanks!