probability

Error when saving model using a `DistributionLambda` layer

Open mkretsch327 opened this issue 3 years ago • 13 comments

When saving a Keras model that incorporates a `DistributionLambda` layer from `tfp.layers`, I receive a stack trace ending in the following (complete stack trace at the end of this post). I observe this error with tfp v0.12.2 and tf 2.5.0, but not with tfp v0.12.2 and tf 2.4.1.

  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 489, in _disallow_when_autograph_enabled
    raise errors.OperatorNotAllowedInGraphError(
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

Please let me know if this better belongs as an issue over on the main tensorflow repo!

Installed package versions

    tensorflow==2.5.0
    tensorflow-probability==0.12.2

Python version: 3.8.8

Script to recreate

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

tfd = tfp.distributions

model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(10,)))
model.add(tf.keras.layers.Dense(2, activation="linear"))
model.add(
    tfp.layers.DistributionLambda(
        lambda t: tfd.Normal(
            loc=t[..., :1], scale=1e-3 + tf.math.softplus(0.1 * t[..., 1:])
        )
    )
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss="mean_absolute_error",
    # List of metrics to monitor
    metrics=["mean_absolute_error"],
)
model.save("~/tf_test_model/")

Complete stack trace

Traceback (most recent call last):
  File "utilities/tf_test_script.py", line 23, in <module>
    model.save("~/tf_test_model/")
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2111, in save
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 150, in save_model
    saved_model_save.save(model, filepath, overwrite, include_optimizer,
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py", line 89, in save
    saved_nodes, node_paths = save_lib.save_and_return_nodes(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1103, in save_and_return_nodes
    _build_meta_graph(obj, signatures, options, meta_graph_def,
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1290, in _build_meta_graph
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 1207, in _build_meta_graph_impl
    signatures = signature_serialization.find_function_to_export(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py", line 99, in find_function_to_export
    functions = saveable_view.list_functions(saveable_view.root)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py", line 154, in list_functions
    obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 2713, in _list_functions_for_serialization
    functions = super(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 3016, in _list_functions_for_serialization
    return (self._trackable_saved_model_saver
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py", line 92, in list_functions_for_serialization
    fns = self.functions_to_serialize(serialization_cache)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 73, in functions_to_serialize
    return (self._get_serialized_attributes(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 89, in _get_serialized_attributes
    object_dict, function_dict = self._get_serialized_attributes_internal(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py", line 53, in _get_serialized_attributes_internal
    super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py", line 99, in _get_serialized_attributes_internal
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 204, in wrap_layer_functions
    fn.get_concrete_function()
  File "/miniconda/envs/optimus/lib/python3.8/contextlib.py", line 120, in __exit__
    next(self.gen)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 367, in tracing_scope
    fn.get_concrete_function(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1367, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 1284, in _get_concrete_function_garbage_collected
    concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3100, in _get_concrete_function_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 599, in wrapper
    ret = method(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in wrap_with_training_arg
    return control_flow_util.smart_cond(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py", line 109, in smart_cond
    return smart_module.smart_cond(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 166, in <lambda>
    training, lambda: replace_training_and_call(True),
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 163, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 681, in call
    return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 639, in __call__
    return self.wrapped_call(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 924, in _call
    results = self._stateful_fn(*args, **kwds)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3022, in __call__
    filtered_flat_args) = self._maybe_define_function(args, kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 599, in wrapper
    ret = method(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 165, in wrap_with_training_arg
    return control_flow_util.smart_cond(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py", line 109, in smart_cond
    return smart_module.smart_cond(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 166, in <lambda>
    training, lambda: replace_training_and_call(True),
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py", line 163, in replace_training_and_call
    return wrapped_call(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py", line 663, in call_and_return_conditional_losses
    call_output = layer_call(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py", line 380, in call
    return super(Sequential, self).call(inputs, training=training, mask=mask)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 420, in call
    return self._run_internal_graph(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 556, in _run_internal_graph
    outputs = node.layer(*args, **kwargs)
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py", line 245, in __call__
    distribution, _ = super(DistributionLambda, self).__call__(
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 520, in __iter__
    self._disallow_iteration()
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 513, in _disallow_iteration
    self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
  File "/miniconda/envs/optimus/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 489, in _disallow_when_autograph_enabled
    raise errors.OperatorNotAllowedInGraphError(
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
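The last TFP frame in this trace is line 245 of `distribution_layer.py`, `distribution, _ = super(DistributionLambda, self).__call__(...)`, and the frames after it are `Tensor.__iter__` and `_disallow_iteration`. That suggests the tuple unpacking on that line received a symbolic `tf.Tensor` rather than a `(distribution, value)` pair while Keras was tracing the layer for SavedModel export. The error class itself can be reproduced in isolation: Python tuple unpacking of a graph tensor inside a `tf.function` is rejected with the same `OperatorNotAllowedInGraphError`. A minimal sketch, assuming any TF 2.x install; the function names are hypothetical and this only demonstrates the error class, not the TFP bug itself:

```python
import tensorflow as tf


@tf.function
def unpack_pair(t):
    # Tuple unpacking calls Tensor.__iter__; AutoGraph does not rewrite it,
    # and iterating over a symbolic (graph) tensor is disallowed.
    first, second = t
    return first + second


def error_name_when_unpacking_in_graph():
    """Return the exception class name raised by unpacking inside tf.function, or None."""
    try:
        unpack_pair(tf.constant([1.0, 2.0]))
    except Exception as err:  # expected to be tf.errors.OperatorNotAllowedInGraphError
        return type(err).__name__
    return None


def unpack_eagerly():
    """The same unpacking outside tf.function works, because eager tensors are iterable."""
    first, second = tf.constant([1.0, 2.0])
    return float(first + second)


if __name__ == "__main__":
    print(error_name_when_unpacking_in_graph())
    print(unpack_eagerly())
```

Outside a traced function the unpacking succeeds, which is consistent with the model working in eager mode and only failing once `model.save()` starts tracing.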

mkretsch327, Jun 03 '21

Hey Matt! I can reproduce this. It also happens with the nightly tensorflow/tfp builds, and with the release candidate and tensorflow==2.5.0, but I don't yet have good leads on a fix.

ColCarroll, Jun 04 '21

Thanks @ColCarroll! I wish I had a better idea of the root cause; I'll keep an eye out here.

mkretsch327, Jun 07 '21

Same error here; I can't move to TensorFlow 2.5 because of this. I am using Python 3.9, though. I cannot save any model that uses `DistributionLambda`.

cserpell, Jun 12 '21

If it helps: I checked that `Layer.__call__` runs to completion, and when it returns, the `__exit__` methods of the `CallContextManager`s start being called, but by then the error seems to have already happened. Using your script, I also tried changing the check in `base_layer_utils.is_subclassed` (called from `__call__`) to force AutoGraph off (no idea what I am doing), and it didn't work.

cserpell, Jun 14 '21

Hi, is there any news on this? It still happens, according to this example I made. Is #1382 related?

cserpell, Aug 03 '21

As a workaround, I can confirm that the weights can be saved and then loaded back into a separately declared model with the same architecture:

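from tensorflow.keras import models

the_model = models.Model(...)  # build and train the model as usual
the_model.save_weights('file_name')
...
# Later: declare the same architecture again, then restore the weights into it.
another_model = models.Model(...)
another_model.load_weights('file_name')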

cserpell, Aug 03 '21

I had a similar problem, but when I tried tfp v0.12.2 with tf 2.4.1 it gave me a different error (see https://github.com/tensorflow/probability/issues/1268). I also posted there the versions I used to make it work.

aciobanusebi, Sep 04 '21

Also facing this issue using tfp==0.13.0 and tf==2.5.0

MathiasGruber, Sep 09 '21

Found a workaround for this issue using tf==2.5.0 and tfp==0.12.2 (have not tested other versions):

By wrapping the call to tfp.layers.DistributionLambda in a user-defined layer class, you can pass that class to the custom_objects argument of tf.keras.models.load_model. The model will save and load as normal.

Here's an example snippet I wrote for a model whose output layer is a 2D Gaussian RV:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


class Normal2DLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(Normal2DLayer, self).__init__()
        self.loc = tf.keras.layers.Dense(2, name='loc')
        self.scale = tf.keras.layers.Dense(2, name='scale')
        self.rotation = tf.keras.layers.Dense(1, name='rotation')     # Parameterizing a 2D Gaussian as a rotated ellipsoid

    def call(self, inputs):
        loc = self.loc(inputs)
        scale_vec = self.scale(inputs)
        theta = self.rotation(inputs)
        scale_diag = tf.linalg.diag(scale_vec)
        # One 2x2 rotation matrix [[cos, -sin], [sin, cos]] per batch element.
        rotation_mat = tf.reshape(
            tf.stack([tf.math.cos(theta), -tf.math.sin(theta),
                      tf.math.sin(theta), tf.math.cos(theta)], axis=1),
            (tf.shape(theta)[0], 2, 2))
        scale_mat = tf.matmul(rotation_mat, scale_diag)
        params = [loc, scale_mat]
        density = tfp.layers.DistributionLambda(
            make_distribution_fn = lambda params: tfd.MultivariateNormalLinearOperator(loc=params[0], scale=tf.linalg.LinearOperatorFullMatrix(params[1])),
        )(params)
        return density

    def get_config(self):
        config = dict()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
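The snippet builds `scale_mat` as a rotation times a diagonal matrix, S = R(θ) diag(s). For `MultivariateNormalLinearOperator` the covariance is S Sᵀ, so Σ = R(θ) diag(s)² R(θ)ᵀ: the `scale` outputs set the ellipse's axis lengths and `rotation` sets its orientation. A small NumPy sketch (unbatched; the helper names are hypothetical) that checks this algebra:

```python
import numpy as np


def rotated_scale(scale_vec, theta):
    """S = R(theta) @ diag(scale_vec), mirroring the snippet's scale_mat for one example."""
    c, s = np.cos(theta), np.sin(theta)
    # Same element order as the tf.stack/tf.reshape above: [[cos, -sin], [sin, cos]].
    rotation = np.array([[c, -s], [s, c]])
    return rotation @ np.diag(scale_vec)


def covariance(scale_vec, theta):
    """Covariance implied by a linear-operator scale S is S @ S.T."""
    scale_mat = rotated_scale(scale_vec, theta)
    return scale_mat @ scale_mat.T


if __name__ == "__main__":
    # Eigenvalues are the squared scales regardless of theta.
    print(np.linalg.eigvalsh(covariance([2.0, 1.0], 0.7)))
```

Because only the squares of the scale entries reach Σ, the unconstrained `Dense` outputs may be negative, but an entry at or near zero makes the scale (nearly) singular; wrapping the scale in a softplus, as the original script does, is one way to avoid that.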

From here, as long as I pass the class definition to the `custom_objects` argument, the normal Keras `save` and `load_model` functions work as expected:

model.save(path_to_model)
model = tf.keras.models.load_model(path_to_model, custom_objects={"Normal2DLayer" : Normal2DLayer})

jdbates, Oct 09 '21

I'm a newbie (about a week of self-learning), but I get the same error with this example and with an unrelated one (the "Probabilistic Bayesian Neural Networks" Git repo, when I just add `save()` after training). It also uses `DistributionLambda`, and it seems to me that the `DistributionLambda` class doesn't implement or override some necessary function. (Output of `python3 -c 'import tensorflow as tf; print(tf.__version__)'`: 2.7.0)

Martinc4321, Dec 31 '21

I am using TF 2.7 and TFP 0.15. If I pass the `save_traces=False` argument to `model.save()`, it works.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import layers

X = np.random.randn(100,10)
y = np.random.randn(100,1)

n_features = X.shape[-1]
n_outputs = y.shape[-1]

inputs = layers.Input(shape = (n_features,), name='features')
x = layers.Dense(16, activation='relu')(inputs)
x = layers.Dense(2)(x)
outputs = tfp.layers.IndependentNormal(n_outputs,  name='outputs')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.save("mymodel", save_traces=False)

del model

model = tf.keras.models.load_model("mymodel")
model.summary()

I found this in the TensorFlow documentation:

The argument save_traces has been added to model.save, which allows you to toggle SavedModel function tracing. Functions are saved to allow the Keras to re-load custom objects without the original class definitions, so when save_traces=False, all custom objects must have defined get_config/from_config methods.

So I guess as long as we use classes that have the `get_config`/`from_config` methods, we should be fine. And as we can see here, the `DistributionLambda` layer, which serves as the base for many tfp layers, has these methods.
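With `save_traces=False`, the contract described in that quote is roughly that each custom layer can be rebuilt from its config alone, i.e. `cls.from_config(layer.get_config())` must reproduce it. A minimal sketch of that contract on a hypothetical layer (`ScaledSoftplus` is illustrative and not part of TF or TFP): the config has to carry every constructor argument, and the base `get_config()` supplies `name`, `trainable` and `dtype`.

```python
import tensorflow as tf


class ScaledSoftplus(tf.keras.layers.Layer):
    """Hypothetical custom layer, used only to illustrate get_config/from_config."""

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return self.scale * tf.math.softplus(inputs)

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add
        # every constructor argument needed to rebuild this layer.
        config = super().get_config()
        config.update({"scale": self.scale})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


if __name__ == "__main__":
    layer = ScaledSoftplus(scale=2.5, name="sp")
    clone = ScaledSoftplus.from_config(layer.get_config())
    print(clone.name, clone.scale)
```

On the loading side, the class is then handed to Keras through `custom_objects`, e.g. `tf.keras.models.load_model(path, custom_objects={"ScaledSoftplus": ScaledSoftplus})`.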

frazane, Feb 23 '22

I have the same issue using tensorflow 2.7.0 and tensorflow-probability 0.15.0. The suggestion from @frazane to use save_traces=False did not work for me. But the solution by @jdbates to use a custom layer did.

willemvdb42, May 02 '22

(Replying to @frazane's `save_traces=False` suggestion above.)

This worked for me, thanks a lot! The original error I had was: OperatorNotAllowedInGraphError: Iterating over a symbolic tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

I'm adding it here because the message differs slightly from the OP's error, but the underlying problem was the same.

Edit: while I was able to save the model, I wasn't able to load it, @frazane. I get an error saying the methods are missing: (1) Implement `get_config` and `from_config` in the layer/model class, and pass the object to the `custom_objects` argument when loading the model. For more details, see: https://www.tensorflow.org/guide/keras/save_and_serialize

I presume they were added in later versions of tfp.

aegonwolf, Dec 19 '22