simclr icon indicating copy to clipboard operation
simclr copied to clipboard

Saving model after finetuning on colab tf2

Open junhuplim opened this issue 4 years ago • 0 comments

Hi all,

I have finetuned a model on my custom dataset and would like to save the model for inference purposes. I realised this wasn't addressed in the colab examples.

I tried to look around how a model is saved in run.py: https://github.com/google-research/simclr/blob/dec99a81a4ceccb0a5a893afecbc2ee18f1d76c3/tf2/run.py#L287

but got the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-215338bd29a6> in <module>
     34         tf.saved_model.save(saved_model, checkpoint_export_dir)
     35 
---> 36 save(model, 1)

<ipython-input-11-215338bd29a6> in save(model, global_step)
     26 def save(model, global_step):
     27     """Export as SavedModel for finetuning and inference."""
---> 28     saved_model = build_saved_model(model)
     29     #   export_dir = os.path.join(FLAGS.model_dir, 'saved_model')
     30     export_dir = model_p

<ipython-input-11-215338bd29a6> in build_saved_model(model, include_projection_head)
     18     module = SimCLRModel(model)
     19     input_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
---> 20     module.__call__.get_concrete_function(input_spec, trainable=True)
     21     module.__call__.get_concrete_function(input_spec, trainable=False)
     22     return module

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1365       ValueError: if this object has not yet been called on concrete values.
   1366     """
-> 1367     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1368     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1369     return concrete

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1271       if self._stateful_fn is None:
   1272         initializers = []
-> 1273         self._initialize(args, kwargs, add_initializers_to=initializers)
   1274         self._initialize_uninitialized_variables(initializers)
   1275 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    762     self._concrete_stateful_fn = (
    763         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 764             *args, **kwds))
    765 
    766     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3048       args, kwargs = None, None
   3049     with self._lock:
-> 3050       graph_function, _ = self._maybe_define_function(args, kwargs)
   3051     return graph_function
   3052 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3442 
   3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
   3445           self._function_cache.primary[cache_key] = graph_function
   3446 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3287             arg_names=arg_names,
   3288             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3289             capture_by_value=self._capture_by_value),
   3290         self._function_attributes,
   3291         function_spec=self.function_spec,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    997         _, original_func = tf_decorator.unwrap(python_func)
    998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
   1000 
   1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    673         return out
    674 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in bound_method_wrapper(*args, **kwargs)
   3969     # However, the replacer is still responsible for attaching self properly.
   3970     # TODO(mdan): Is it possible to do it here instead?
-> 3971     return wrapped_fn(*args, **kwargs)
   3972   weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
   3973 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    984           except Exception as e:  # pylint:disable=broad-except
    985             if hasattr(e, "ag_error_metadata"):
--> 986               raise e.ag_error_metadata.to_exception(e)
    987             else:
    988               raise

ValueError: in user code:

    <ipython-input-11-215338bd29a6>:15 __call__  *
        self.model(inputs, training=trainable)
    <ipython-input-5-b312b9cd7f3b>:20 call  *
        outputs = self.saved_model(x[0], trainable=False)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:670 _call_attribute  **
        return instance.__call__(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:889 __call__
        result = self._call(*args, **kwds)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:924 _call
        results = self._stateful_fn(*args, **kwds)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:3022 __call__
        filtered_flat_args) = self._maybe_define_function(args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:3444 _maybe_define_function
        graph_function = self._create_graph_function(args, kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:3289 _create_graph_function
        capture_by_value=self._capture_by_value),
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:999 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:672 wrapped_fn
        out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/function_deserialization.py:291 restored_function_body
        "\n\n".join(signature_descriptions)))

    ValueError: Could not find matching function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * Tensor("inputs:0", shape=(None, None, 3), dtype=float32)
        * False
      Keyword arguments: {}
    
    Expected these arguments to match one of the following 2 option(s):
    
    Option 1:
      Positional arguments (2 total):
        * TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='inputs')
        * True
      Keyword arguments: {}
    
    Option 2:
      Positional arguments (2 total):
        * TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='inputs')
        * False
      Keyword arguments: {}

This is how my Model is defined:


class Model(tf.keras.Model):
    def __init__(self, path):
        super(Model, self).__init__()
        self.saved_model = tf.saved_model.load(path)
        self.dense_layer = tf.keras.layers.Dense(units=num_classes, name="head_supervised_new")
        self.optimizer = LARSOptimizer(
          learning_rate,
          momentum=momentum,
          weight_decay=weight_decay,
          exclude_from_weight_decay=['batch_normalization', 'bias', 'head_supervised'])

    def call(self, x):
        with tf.GradientTape() as tape:
            outputs = self.saved_model(x[0], trainable=False)
            logits_t = self.dense_layer(outputs['final_avg_pool'])
            one_hot_labels = tf.one_hot(tf.reshape(x[1], [-1]), num_classes)
            loss_t = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = one_hot_labels, logits=logits_t))
            dense_layer_weights = self.dense_layer.trainable_weights
            grads = tape.gradient(loss_t, dense_layer_weights)
            self.optimizer.apply_gradients(zip(grads, dense_layer_weights))
        return loss_t, x[0], logits_t, tf.reshape(x[1], [-1])

model = Model("gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/")

# Remove this for debugging.  
@tf.function
def train_step(x):
    return model(x)

Appreciate any help given! :D

junhuplim avatar Nov 01 '21 10:11 junhuplim