tf-raft

Bug when trying to save full model

Open kmonachopoulos opened this issue 2 years ago • 0 comments

Hi,

I managed to train and validate the model on a small subset of the dataset, and I am now trying to save the full model, rather than only the weights, through the checkpoint callback.

According to https://keras.io/api/callbacks/model_checkpoint/, to do that I have to set save_weights_only to False so that the callback saves the whole model.
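A minimal sketch of that configuration, assuming a recent TensorFlow release that ships Keras 3, where full-model checkpoint paths must end in .keras. In the Keras 2 versions current when this issue was filed, a path without an extension was written as a SavedModel directory instead. The checkpoints/raft_full.keras path is hypothetical.

```python
import tensorflow as tf

# Full-model checkpointing: with save_weights_only=False the callback calls
# model.save(...) rather than model.save_weights(...).
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/raft_full.keras",  # hypothetical path; Keras 3 requires .keras here
    save_weights_only=False,
)

# Hypothetical usage: model.fit(train_ds, epochs=10, callbacks=[checkpoint])
```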

However, when I change this value, I get the following error:

TypeError: in user code:

    File "/Users/xxxxxx/miniforge3/lib/python3.9/site-packages/keras/saving/saving_utils.py", line 138, in _wrapped_model  *
        outputs = model(*args, **kwargs)
    File "/Users/xxxxxx/miniforge3/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/var/folders/lh/57ty9td15gv9gd258s8d2sg00000gn/T/__autograph_generated_filelj5z_b00.py", line 16, in tf__call
        correlation = ag__.converted_call(ag__.ld(CorrBlock), (ag__.ld(fmap1), ag__.ld(fmap2)), dict(num_levels=ag__.ld(self).corr_levels, radius=ag__.ld(self).corr_radius), fscope)
    File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 106, in __init__
        corr = self.correlation(fmap1, fmap2)
    File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 156, in correlation
        fmap1 = tf.reshape(fmap1, (batch_size, h*w, nch))

    TypeError: Exception encountered when calling layer "raft" (type RAFT).

    in user code:

        File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/model.py", line 80, in call  *
            correlation = CorrBlock(fmap1, fmap2,
        File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 106, in __init__  **
            corr = self.correlation(fmap1, fmap2)
        File "/Users/xxxxxx/xxxxxx/arm-use-cases/tf-raft/tf_raft/layers/corr.py", line 156, in correlation
            fmap1 = tf.reshape(fmap1, (batch_size, h*w, nch))

        TypeError: Failed to convert elements of (None, 2852, 256) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.


    Call arguments received by layer "raft" (type RAFT):
      • inputs=['tf.Tensor(shape=(None, 368, 496, 3), dtype=float32)', 'tf.Tensor(shape=(None, 368, 496, 3), dtype=float32)']
      • training=False

    2022-11-20 21:24:32.326266: W tensorflow/core/kernels/data/generator_dataset_op.cc:108] Error occurred when finalizing GeneratorDataset iterator: FAILED_PRECONDITION: Python interpreter state is not initialized. The process may be terminated.
	 [[{{node PyFunc}}]]
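The failure can be reproduced in isolation. This is a minimal sketch, assuming TensorFlow 2.x, with a hypothetical flatten_static function standing in for the tf-raft code: when the full model is saved, its call is traced with an unknown batch size (the inputs above have shape (None, 368, 496, 3)), so unpacking the static .shape of a feature map yields None for the batch entry. The 46 x 62 feature map used here matches 368/8 x 496/8, which gives the h*w = 2852 seen in the error.

```python
import tensorflow as tf


def flatten_static(fmap):
    # Hypothetical function, not the tf-raft source.
    # Static shape: the batch entry is None while tracing with an unknown batch size.
    batch_size, h, w, nch = fmap.shape
    return tf.reshape(fmap, (batch_size, h * w, nch))


# Unknown batch size, 46 x 62 x 256 feature map (h*w = 2852).
spec = tf.TensorSpec((None, 46, 62, 256), tf.float32)

try:
    tf.function(flatten_static).get_concrete_function(spec)
    trace_failed = False
except (TypeError, ValueError) as err:
    # The issue shows a TypeError; ValueError is also caught in case
    # another TensorFlow version reports the conversion failure differently.
    trace_failed = True
    print(type(err).__name__, err)

print("trace failed:", trace_failed)
```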

It seems that the problem is inside the CorrBlock class, specifically in its correlation method: tf.reshape receives the shape (None, 2852, 256), and the None batch dimension cannot be converted to a tensor.
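One common workaround, sketched here as an assumption rather than the fix adopted in tf-raft, is to read the dimensions from tf.shape, which returns the run-time shape as a tensor, so the batch size no longer has to be known when the graph is traced. The correlation function below is a hypothetical stand-in for CorrBlock.correlation; it computes the all-pairs dot-product volume that RAFT uses (features flattened to (batch, h*w, channels), matrix-multiplied, and scaled by the square root of the channel count).

```python
import tensorflow as tf


def correlation(fmap1, fmap2):
    # Hypothetical stand-in for CorrBlock.correlation, not the tf-raft source.
    # tf.shape gives the dynamic shape, so batch_size is a tensor at run time
    # instead of None at trace time.
    shape = tf.shape(fmap1)
    batch_size, h, w, nch = shape[0], shape[1], shape[2], shape[3]
    fmap1 = tf.reshape(fmap1, (batch_size, h * w, nch))
    fmap2 = tf.reshape(fmap2, (batch_size, h * w, nch))
    # Dot product between every pixel of fmap1 and every pixel of fmap2.
    corr = tf.matmul(fmap1, fmap2, transpose_b=True)
    corr = tf.reshape(corr, (batch_size, h, w, 1, h, w))
    return corr / tf.sqrt(tf.cast(nch, corr.dtype))


# Tracing with an unknown batch size now succeeds.
spec = tf.TensorSpec((None, 46, 62, 256), tf.float32)
traced = tf.function(correlation).get_concrete_function(spec, spec)

# Small eager check: with all-ones inputs every entry equals 8 / sqrt(8).
x = tf.ones((2, 4, 5, 8))
out = correlation(x, x)
print(out.shape)
```

When only the batch dimension is unknown, passing -1 as the batch entry of tf.reshape is another option; tf.shape has the advantage of also covering unknown spatial sizes.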

Do you have any suggestions for this issue? My end goal is to export a full checkpoint, protobuf, or Keras model so that I can convert it to TFLite.
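For that end goal, one possible route, assuming the reshape has been made batch-agnostic, is to export a SavedModel with an explicit input signature and convert it with tf.lite.TFLiteConverter.from_saved_model. The hypothetical CorrelationModule below wraps only the correlation step (returning the flattened (batch, h*w, h*w) volume) to keep the sketch small; wrapping the full RAFT model the same way is an assumption that has not been checked against tf-raft.

```python
import tempfile

import tensorflow as tf


class CorrelationModule(tf.Module):
    # Hypothetical wrapper around the correlation step only.
    @tf.function(input_signature=[
        tf.TensorSpec((1, 46, 62, 256), tf.float32),
        tf.TensorSpec((1, 46, 62, 256), tf.float32),
    ])
    def __call__(self, fmap1, fmap2):
        shape = tf.shape(fmap1)
        batch_size, hw, nch = shape[0], shape[1] * shape[2], shape[3]
        f1 = tf.reshape(fmap1, (batch_size, hw, nch))
        f2 = tf.reshape(fmap2, (batch_size, hw, nch))
        corr = tf.matmul(f1, f2, transpose_b=True)
        return corr / tf.sqrt(tf.cast(nch, corr.dtype))


module = CorrelationModule()
export_dir = tempfile.mkdtemp()

# Write a SavedModel (protobuf graph plus variables) with a serving signature.
tf.saved_model.save(module, export_dir, signatures=module.__call__)

# Convert the SavedModel into a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()
print(len(tflite_model), "bytes")
```

A fixed batch size of 1 keeps the exported signature fully static, which is the simplest case for the converter.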

Thanks!

kmonachopoulos, Nov 20 '22 21:11