
JointDistributionSequential cannot be serialized via tf.train.Checkpoint

Open: sid-kap opened this issue 3 years ago · 1 comment

I'm using tensorflow version 2.11.0 and tensorflow_probability 0.19.0.

The following code succeeds:

import tensorflow as tf
import tensorflow_probability as tfp

class MyModule(tf.Module):
    def __init__(self):
        self.dist = tfp.distributions.Normal(0.0, 1.0)
my_module = MyModule()

checkpoint = tf.train.Checkpoint(my_module)
checkpoint.save("checkpoint")

but the following code fails:

import tensorflow as tf
import tensorflow_probability as tfp

class MyModule(tf.Module):
    def __init__(self):
        self.dist = tfp.distributions.JointDistributionSequential(
            [
                tfp.distributions.Normal(0.0, 1.0),
                tfp.distributions.Normal(0.0, 1.0),
            ]
        )
my_module = MyModule()

checkpoint = tf.train.Checkpoint(my_module)
checkpoint.save("checkpoint")

with the error:

ValueError: Unable to save the object {-1: ListWrapper([<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>])} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.

The problem seems to be that the JointDistribution._single_sample_distributions dict has an int key (-1 in the error above), which tf.train.Checkpoint cannot save. If I add

my_module.dist._single_sample_distributions = {}

before checkpointing then it succeeds.

sid-kap, Jan 12 '23 21:01

You solved my problems. Thank you!

brifin, Mar 09 '23 05:03