JointDistributionSequential cannot be serialized via tf.train.Checkpoint
I'm using tensorflow version 2.11.0 and tensorflow_probability 0.19.0.
The following code succeeds:
```python
import tensorflow as tf
import tensorflow_probability as tfp


class MyModule(tf.Module):
    def __init__(self):
        self.dist = tfp.distributions.Normal(0.0, 1.0)


my_module = MyModule()
checkpoint = tf.train.Checkpoint(my_module)
checkpoint.save("checkpoint")
```
but the following code fails:
```python
import tensorflow as tf
import tensorflow_probability as tfp


class MyModule(tf.Module):
    def __init__(self):
        self.dist = tfp.distributions.JointDistributionSequential(
            [
                tfp.distributions.Normal(0.0, 1.0),
                tfp.distributions.Normal(0.0, 1.0),
            ]
        )


my_module = MyModule()
checkpoint = tf.train.Checkpoint(my_module)
checkpoint.save("checkpoint")
```
with the error:
```
ValueError: Unable to save the object {-1: ListWrapper([<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>, <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>])} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a non-trackable object; it will be subsequently ignored.
```
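For what it's worth, the restriction seems to come from TensorFlow's automatic attribute tracking rather than from TFP itself: any `tf.Module` attribute that is a dict with a non-string key mapping to a trackable object hits the same error at save time. A minimal sketch without TFP (the `Holder` class and its `cache` attribute are made up for illustration):

```python
import tensorflow as tf


class Holder(tf.Module):
    def __init__(self):
        super().__init__()
        # tf.Module auto-wraps this dict for dependency tracking; its int
        # key maps to a trackable (tf.Variable), which checkpointing rejects.
        self.cache = {-1: tf.Variable(1.0)}


holder = Holder()
err = None
try:
    tf.train.Checkpoint(holder).save("holder_ckpt")
except ValueError as e:
    err = e
print(err)
```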
It seems to be an issue with the `JointDistribution._single_sample_distributions` dict having int keys. If I add

```python
my_module.dist._single_sample_distributions = {}
```

before checkpointing, then it succeeds.
You solved my problems. Thank you!