Saving a model after finetuning on Colab (TF2)
Hi all,
I have finetuned a model on my custom dataset and would like to save it for inference. I realised this isn't covered in the Colab examples.
I looked at how a model is saved in run.py (https://github.com/google-research/simclr/blob/dec99a81a4ceccb0a5a893afecbc2ee18f1d76c3/tf2/run.py#L287) and tried to replicate it in my notebook.
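Roughly, the cell I put together looks like the sketch below, adapted from run.py's `build_saved_model`/`save` helpers (reconstructed; `model_p` is an export path I define earlier in the notebook, and I return the model outputs directly instead of run.py's dict of endpoint tensors):

```python
import os
import tensorflow as tf

def build_saved_model(model, include_projection_head=True):
  """Wraps the finetuned Keras model in a tf.Module, following run.py."""
  # include_projection_head is used by run.py when choosing outputs;
  # it is unused in this trimmed sketch.

  class SimCLRModel(tf.Module):
    """Saved model for exporting."""

    def __init__(self, model):
      super().__init__()
      self.model = model

    @tf.function
    def __call__(self, inputs, trainable):
      # run.py returns a dict of endpoint tensors here; the error below is
      # raised while tracing this call, before any outputs are used.
      return self.model(inputs, training=trainable)

  module = SimCLRModel(model)
  # Trace training and inference signatures with a 4-D (batched) image input.
  input_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
  module.__call__.get_concrete_function(input_spec, trainable=True)
  module.__call__.get_concrete_function(input_spec, trainable=False)
  return module


def save(model, global_step):
  """Export as SavedModel for finetuning and inference."""
  saved_model = build_saved_model(model)
  # export_dir = os.path.join(FLAGS.model_dir, 'saved_model')
  export_dir = model_p  # export path set earlier in the notebook
  checkpoint_export_dir = os.path.join(export_dir, str(global_step))
  tf.saved_model.save(saved_model, checkpoint_export_dir)


save(model, 1)
```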
Running that cell gave the following error:
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-11-215338bd29a6> in <module>
34 tf.saved_model.save(saved_model, checkpoint_export_dir)
35
---> 36 save(model, 1)
<ipython-input-11-215338bd29a6> in save(model, global_step)
26 def save(model, global_step):
27 """Export as SavedModel for finetuning and inference."""
---> 28 saved_model = build_saved_model(model)
29 # export_dir = os.path.join(FLAGS.model_dir, 'saved_model')
30 export_dir = model_p
<ipython-input-11-215338bd29a6> in build_saved_model(model, include_projection_head)
18 module = SimCLRModel(model)
19 input_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
---> 20 module.__call__.get_concrete_function(input_spec, trainable=True)
21 module.__call__.get_concrete_function(input_spec, trainable=False)
22 return module
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
1365 ValueError: if this object has not yet been called on concrete values.
1366 """
-> 1367 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1368 concrete._garbage_collector.release() # pylint: disable=protected-access
1369 return concrete
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
1271 if self._stateful_fn is None:
1272 initializers = []
-> 1273 self._initialize(args, kwargs, add_initializers_to=initializers)
1274 self._initialize_uninitialized_variables(initializers)
1275
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
762 self._concrete_stateful_fn = (
763 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 764 *args, **kwds))
765
766 def invalid_creator_scope(*unused_args, **unused_kwds):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3048 args, kwargs = None, None
3049 with self._lock:
-> 3050 graph_function, _ = self._maybe_define_function(args, kwargs)
3051 return graph_function
3052
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3287 arg_names=arg_names,
3288 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3289 capture_by_value=self._capture_by_value),
3290 self._function_attributes,
3291 function_spec=self.function_spec,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673 return out
674
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in bound_method_wrapper(*args, **kwargs)
3969 # However, the replacer is still responsible for attaching self properly.
3970 # TODO(mdan): Is it possible to do it here instead?
-> 3971 return wrapped_fn(*args, **kwargs)
3972 weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
3973
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
984 except Exception as e: # pylint:disable=broad-except
985 if hasattr(e, "ag_error_metadata"):
--> 986 raise e.ag_error_metadata.to_exception(e)
987 else:
988 raise
ValueError: in user code:
<ipython-input-11-215338bd29a6>:15 __call__ *
self.model(inputs, training=trainable)
<ipython-input-5-b312b9cd7f3b>:20 call *
outputs = self.saved_model(x[0], trainable=False)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:670 _call_attribute **
return instance.__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:889 __call__
result = self._call(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:924 _call
results = self._stateful_fn(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:3022 __call__
filtered_flat_args) = self._maybe_define_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:3444 _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:3289 _create_graph_function
capture_by_value=self._capture_by_value),
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:999 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:672 wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/function_deserialization.py:291 restored_function_body
"\n\n".join(signature_descriptions)))
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (2 total):
* Tensor("inputs:0", shape=(None, None, 3), dtype=float32)
* False
Keyword arguments: {}
Expected these arguments to match one of the following 2 option(s):
Option 1:
Positional arguments (2 total):
* TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='inputs')
* True
Keyword arguments: {}
Option 2:
Positional arguments (2 total):
* TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='inputs')
* False
Keyword arguments: {}
```
From the last lines of the traceback, it looks like the loaded SavedModel is being traced with a 3-D tensor of shape (None, None, 3) where its signature expects a 4-D batched input of shape (None, None, None, 3). This is how my Model is defined:
```python
# num_classes, learning_rate, momentum, weight_decay and LARSOptimizer
# are defined earlier in the notebook (as in the finetuning Colab).
class Model(tf.keras.Model):

  def __init__(self, path):
    super(Model, self).__init__()
    # Pretrained SimCLR SavedModel used as a frozen feature extractor.
    self.saved_model = tf.saved_model.load(path)
    # New linear head for the custom dataset.
    self.dense_layer = tf.keras.layers.Dense(units=num_classes, name="head_supervised_new")
    self.optimizer = LARSOptimizer(
        learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
        exclude_from_weight_decay=['batch_normalization', 'bias', 'head_supervised'])

  def call(self, x):
    # x is an (images, labels) pair; only the new head is trained.
    with tf.GradientTape() as tape:
      outputs = self.saved_model(x[0], trainable=False)
      logits_t = self.dense_layer(outputs['final_avg_pool'])
      one_hot_labels = tf.one_hot(tf.reshape(x[1], [-1]), num_classes)
      loss_t = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels, logits=logits_t))
    dense_layer_weights = self.dense_layer.trainable_weights
    grads = tape.gradient(loss_t, dense_layer_weights)
    self.optimizer.apply_gradients(zip(grads, dense_layer_weights))
    return loss_t, x[0], logits_t, tf.reshape(x[1], [-1])


model = Model("gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/")

# Remove this for debugging.
@tf.function
def train_step(x):
  return model(x)
```
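For context, the data side looks roughly like this (illustrative only; the shapes and batch size are placeholders, not my real pipeline), i.e. `x[0]` is a 4-D float32 batch of images and `x[1]` holds the integer labels:

```python
# Illustrative stand-in for my tf.data pipeline of (image, label) batches.
images = tf.random.uniform([32, 224, 224, 3], dtype=tf.float32)  # values in [0, 1)
labels = tf.random.uniform([32], maxval=num_classes, dtype=tf.int32)

loss, _, logits, flat_labels = train_step((images, labels))
```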
Appreciate any help given! :D