
Error when inheriting from JointDistributionSequential in tfp version 0.18.0

Open gileshd opened this issue 2 years ago • 3 comments

In the most recent release (0.18.0), instantiating a class that inherits from tfp.distributions.JointDistributionSequential now raises an error.

Here is a minimal working example (MWE), reproduced in Colab:

from jax import numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

class HierarchicalNormal(tfd.JointDistributionSequential):
    def __init__(self, loc, scale):
        self.loc = loc
        self.prior_scale = scale

        super().__init__(model=[
            tfd.Normal(loc, scale),
            lambda mu: tfd.Normal(mu, scale),
        ])


loc = jnp.zeros(1)
scale = jnp.ones(1)
dist = HierarchicalNormal(loc, scale)

raises the following error:

    388     if not isinstance(model, collections.abc.Sequence):
    389       raise TypeError('`model` must be `list`-like (saw: {}).'.format(
--> 390           type(model).__name__))
    391     self._dist_fn = model
    392     self._dist_fn_wrapped, self._dist_fn_args = zip(*[

TypeError: `model` must be `list`-like (saw: DeviceArray).
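The guard in the traceback is a plain isinstance(model, collections.abc.Sequence) check. A minimal sketch of why it fails here, assuming only that jax is installed: a Python list passes, but a JAX array (reported as DeviceArray in this JAX version) does not, which suggests the subclass's first positional argument, loc, reached this check in place of the model list.

```python
import collections.abc

from jax import numpy as jnp

loc = jnp.zeros(1)

# The same check JointDistributionSequential applies to `model`.
list_ok = isinstance([loc], collections.abc.Sequence)   # a list is a Sequence
array_ok = isinstance(loc, collections.abc.Sequence)    # a JAX array is not

print(list_ok)   # True
print(array_ok)  # False
```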

Note that the following works:

loc = jnp.zeros(1)
scale = jnp.ones(1)
model = [tfd.Normal(loc, scale),
         lambda mu: tfd.Normal(mu, scale)]

dist = tfd.JointDistributionSequential(model)

This appears to be a regression in 0.18.0: the same code runs without error in versions 0.17.0 and 0.16.0.

(tagging @murphyk @slinderman to be kept up to date).

gileshd, Sep 13 '22

This Colab confirms that the problem does not occur in 0.16 (the Colab default) but does occur in 0.18.

murphyk, Sep 13 '22

As a workaround, I think you can add this method to your class:

def __new__(cls, *args, **kwargs):
    # Skip JointDistributionSequential.__new__ and allocate the instance directly.
    return tfd.Distribution.__new__(cls)

SiegeLordEx, Sep 13 '22

Hi @SiegeLordEx, yes, that fixes it. Thanks!

murphyk, Sep 13 '22