Error when inheriting from JointDistributionSequential in tfp version 0.18.0
In the most recent release (0.18.0), instantiating a class that inherits from tfp.distributions.JointDistributionSequential now raises an error.
Here is a minimal working example (MWE), reproduced in Colab:
```python
from jax import numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class HierarchicalNormal(tfd.JointDistributionSequential):
    def __init__(self, loc, scale):
        self.loc = loc
        self.prior_scale = scale
        # mu ~ Normal(loc, scale); x | mu ~ Normal(mu, scale)
        super().__init__(model=[
            tfd.Normal(loc, scale),
            lambda mu: tfd.Normal(mu, scale),
        ])


loc = jnp.zeros(1)
scale = jnp.ones(1)
dist = HierarchicalNormal(loc, scale)  # raises TypeError in 0.18.0
```
raises the following error:
```
    388     if not isinstance(model, collections.abc.Sequence):
    389       raise TypeError('`model` must be `list`-like (saw: {}).'.format(
--> 390           type(model).__name__))
    391     self._dist_fn = model
    392     self._dist_fn_wrapped, self._dist_fn_args = zip(*[

TypeError: `model` must be `list`-like (saw: DeviceArray).
```
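One plausible reading of this traceback: the subclass always passes model= as a list, so the DeviceArray being rejected cannot be that list; it looks like the constructor's first positional argument (loc) is being picked up as `model` before HierarchicalNormal.__init__ runs, presumably by a `__new__` inherited from JointDistributionSequential in 0.18.0 (an inference, not confirmed from the TFP source). The pure-Python sketch below reproduces that failure mode; MockJointSequential and MockHierarchical are hypothetical stand-ins, not TFP code.

```python
import collections.abc


class MockJointSequential:
    """Hypothetical base class whose __new__ validates `model` early."""

    def __new__(cls, *args, **kwargs):
        # Assumption for illustration: the first positional argument (or the
        # `model` keyword) is treated as the model, whatever the subclass means.
        model = args[0] if args else kwargs.get("model")
        if not isinstance(model, collections.abc.Sequence):
            raise TypeError(
                "`model` must be `list`-like (saw: {}).".format(type(model).__name__))
        return super().__new__(cls)

    def __init__(self, model):
        self.model = list(model)


class MockHierarchical(MockJointSequential):
    def __init__(self, loc, scale):
        # Never reached: __new__ has already rejected `loc` as the model.
        super().__init__(model=[("Normal", loc, scale)])


def try_build():
    try:
        MockHierarchical(0.0, 1.0)
    except TypeError as err:
        return str(err)
    return "no error"


message = try_build()
# Constructing the base class directly with a list is fine, mirroring the
# "following works" case in the thread.
direct = MockJointSequential(model=[1, 2])
```

Here `message` is "`model` must be `list`-like (saw: float).", the same shape of error as the traceback, while `direct` builds normally.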
Note that the following works:
```python
loc = jnp.zeros(1)
scale = jnp.ones(1)
model = [tfd.Normal(loc, scale),
         lambda mu: tfd.Normal(mu, scale)]
dist = tfd.JointDistributionSequential(model)
```
This seems to be new in 0.18.0 with no errors in versions 0.17.0 or 0.16.0.
(tagging @murphyk @slinderman to be kept up to date).
This Colab makes it clear that the problem does not occur in 0.16 (the Colab default) but does occur in 0.18.
As a workaround, I think you can add this method to your class:
```python
def __new__(cls, *args, **kwargs):
    # Bypass the __new__ inherited from JointDistributionSequential.
    return tfd.Distribution.__new__(cls)
```
Hi @SiegeLordEx, yes, that fixes it. Thanks!