
Loading a saved model results in a different structure; error when using the loaded model.

JaeLee18 opened this issue 2 years ago • 3 comments

Hello, I have a model like this:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from spektral.layers import GCNConv, GlobalSumPool

l2_reg = 5e-4  # Regularization rate for l2


class MyFirstGNN(Model):

    def __init__(self, n_labels):
        super().__init__()
        self.conv1 = GCNConv(32, activation="elu",
                             kernel_regularizer=l2(l2_reg))
        self.conv2 = GCNConv(32, activation="elu",
                             kernel_regularizer=l2(l2_reg))
        self.pool = GlobalSumPool()
        self.dropout = Dropout(0.2)
        self.fc1 = Dense(512, activation="relu")
        self.dense = Dense(n_labels, activation="softmax")

    def call(self, inputs):
        x, a = inputs
        out = self.conv1([x, a])
        out = self.conv2([out, a])
        out = self.dropout(out)
        out = self.pool(out)
        out = self.fc1(out)
        out = self.dense(out)
        return out

I trained this model using the following code: model.fit(tr_loader.load(), steps_per_epoch=tr_loader.steps_per_epoch, epochs=100)
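For context, the post does not show how tr_loader was built or how the model was compiled; based on the later mention of BatchLoader, the surrounding setup presumably looked roughly like the sketch below. The dataset variable, batch size, optimizer, and loss are assumptions, not from the thread.

from spektral.data import BatchLoader

# `dataset` stands in for the custom spektral.data.Dataset mentioned later in the thread
tr_loader = BatchLoader(dataset, batch_size=32)  # batch_size is a guess

model = MyFirstGNN(dataset.n_labels)
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["acc"])
model.fit(tr_loader.load(),
          steps_per_epoch=tr_loader.steps_per_epoch,
          epochs=100)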

I saved this model with model.save('modelPath', save_format='tf') and then loaded it back with new_model = tf.keras.models.load_model('modelPath').

The model loads successfully. However, when I try to evaluate it on the same training data with new_model.evaluate(tr_loader.load()), I get this error:

InvalidArgumentError:  Input to reshape is a tensor with 50368 values, but the requested shape has 53184
	 [[{{node my_first_gnn_4/StatefulPartitionedCall/StatefulPartitionedCall/gcn_conv_8/Reshape_2}}]] [Op:__inference_test_function_123064]

Function call stack:
test_function

Here is the full error stack:

InvalidArgumentError                      Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_14196/3741965917.py in <module>
----> 1 new_model.evaluate(tr_loader.load())

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\keras\engine\training.py in evaluate(self, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, return_dict)
   1377             with trace.Trace('TraceContext', graph_type='test', step_num=step):
   1378               callbacks.on_test_batch_begin(step)
-> 1379               tmp_logs = test_function(iterator)
   1380               if data_handler.should_sync:
   1381                 context.async_wait()

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    812       # In this case we have not created variables on the first call. So we can
    813       # run the first trace but we should fail if variables are created.
--> 814       results = self._stateful_fn(*args, **kwds)
    815       if self._created_variables:
    816         raise ValueError("Creating variables on a non-first call to a function"

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2827     with self._lock:
   2828       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2830 
   2831   @property

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1841       `args` and `kwargs`.
   1842     """
-> 1843     return self._call_flat(
   1844         [t for t in nest.flatten((args, kwargs), expand_composites=True)
   1845          if isinstance(t, (ops.Tensor,

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1921         and executing_eagerly):
   1922       # No tape is watching; skip to running the function.
-> 1923       return self._build_call_outputs(self._inference_function.call(
   1924           ctx, args, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    543       with _InterpolateFunctionError(self):
    544         if cancellation_manager is None:
--> 545           outputs = execute.execute(
    546               str(self.signature.name),
    547               num_outputs=self._num_outputs,

~\anaconda3\envs\tf\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError:  Input to reshape is a tensor with 50368 values, but the requested shape has 53184
	 [[{{node my_first_gnn_4/StatefulPartitionedCall/StatefulPartitionedCall/gcn_conv_8/Reshape_2}}]] [Op:__inference_test_function_123064]

Function call stack:
test_function

JaeLee18 avatar Oct 14 '21 02:10 JaeLee18

Hi,

what dataset is this happening with? Can you provide a minimal example to reproduce the crash?

Thanks!

danielegrattarola avatar Oct 14 '21 11:10 danielegrattarola

> Hi,
>
> what dataset is this happening with? Can you provide a minimal example to reproduce the crash?
>
> Thanks!

Hello,

It was a custom dataset loaded with BatchLoader. Instead of using load_model, I worked around the issue like this:

model.save_weights('path')          # save only the weights
newModel = MyFirstGNN(n_label)      # rebuild the architecture from the class
newModel.load_weights('path')       # restore the weights into the new instance
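(Note: the rebuilt model typically also needs to be compiled again before evaluate can run, since save_weights stores only the weights and not the loss/metric configuration.)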

Is there any way to save this model in HDF5 format?

JaeLee18 avatar Oct 14 '21 15:10 JaeLee18

Exporting to HDF5 should work out of the box; you just need to make sure that you pass a custom_objects dictionary when loading the model back.
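For reference, a minimal sketch of what that might look like; the file name and the exact set of entries needed in custom_objects are assumptions, not from the thread:

import tensorflow as tf
from spektral.layers import GCNConv, GlobalSumPool

# Save in HDF5 format instead of the TF SavedModel format
model.save('model.h5', save_format='h5')

# Map the Spektral layer names (and the custom Model subclass) back to their
# classes so that Keras can rebuild them when deserializing
new_model = tf.keras.models.load_model(
    'model.h5',
    custom_objects={
        'GCNConv': GCNConv,
        'GlobalSumPool': GlobalSumPool,
        'MyFirstGNN': MyFirstGNN,
    },
)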

danielegrattarola avatar Oct 19 '21 09:10 danielegrattarola