
`JointDistributionCoroutineAutoBatched` seed behaviour with batches

Open chrism0dwk opened this issue 6 months ago • 0 comments

Hi all,

I'm trying to compute a posterior predictive distribution over samples from a posterior distribution (Colab here). I'm using TFP 0.25 with the JAX backend.

My (MRE, and therefore contrived) model specification is

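import jax
import numpy as np
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def model_autobatched():
    theta = yield tfd.Normal(loc=0., scale=1., name="theta")
    yield tfd.Normal(loc=theta, scale=0.1, name="y")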

i.e. a Normally distributed observation model with a Normally distributed mean. To compute the posterior predictive distribution, I wish to sample the y component conditional on a vector of theta samples, with one independent draw of y per theta sample.

theta_samples = np.arange(5.)
model_autobatched.sample(theta=theta_samples, seed=jax.random.key(0))

giving

StructTuple(
  theta=Array([0., 1., 2., 3., 4.], dtype=float32),
  y=Array([0.06215769, 1.0621576 , 2.0621576 , 3.0621576 , 4.0621576 ],      dtype=float32)
)

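Oh dear, we notice that y - theta is constant (about 0.0622) across all five elements. This seems to suggest that a single PRNG key, and hence a single noise draw, is being reused for every element of y, rather than an independent draw being made for each sample of theta.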
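To make the suspected mechanism concrete, here is a minimal pure-JAX sketch. It does not use TFP and makes no claim about what `JointDistributionCoroutineAutoBatched` does internally; `sample_y` is a hypothetical stand-in for the `y` node of the model. It only shows that reusing one key across a vectorised draw reproduces the constant-offset symptom, whereas splitting the key gives independent draws.

```python
import jax
import jax.numpy as jnp

theta_samples = jnp.arange(5.)
key = jax.random.key(0)

def sample_y(theta, key):
    # Hypothetical stand-in for y ~ Normal(loc=theta, scale=0.1).
    return theta + 0.1 * jax.random.normal(key)

# Same key broadcast to every element: every draw gets identical noise.
y_shared = jax.vmap(sample_y, in_axes=(0, None))(theta_samples, key)

# One key per element: the noise differs between elements.
keys = jax.random.split(key, theta_samples.shape[0])
y_indep = jax.vmap(sample_y)(theta_samples, keys)

resid_shared = y_shared - theta_samples  # constant across elements
resid_indep = y_indep - theta_samples    # varies across elements
```

If the output above arises in this way, the residuals of the model's sample would behave like `resid_shared` rather than `resid_indep`.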

Moreover, this approach fails entirely if sample_distributions is called with the same arguments instead of sample.

model_autobatched.sample_distributions(theta=theta_samples, seed=jax.random.key(0))
ValueError: Attempt to convert a value (<object object at 0x7a53561590d0>) with an unsupported type (<class 'object'>) to a Tensor.

~~As a workaround, we could use the older JointDistributionCoroutine with Root annotation which works as desired (see Colab)~~

[edit] Actually, JDCoroutine/Root only works because the whole theta vector is passed to y's constructor, so y gets a batch shape of 5. It does not work because the model as a whole is vectorised over theta.
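As a rough pure-JAX analogue of that edit (again a sketch, not TFP code): when the whole theta vector reaches the distribution for y, a single key is used for one draw whose shape matches the vector, so the noise elements are still independent of one another. The names below are illustrative only.

```python
import jax
import jax.numpy as jnp

theta_samples = jnp.arange(5.)
key = jax.random.key(0)

# One key, but a draw with the full batch shape: five distinct noise values.
noise = 0.1 * jax.random.normal(key, shape=theta_samples.shape)
y = theta_samples + noise
```

This is why the non-autobatched route appears to behave as desired: the batching happens inside a single distribution rather than across the joint model.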

Do we have a bug or a feature, I wonder?

Regards,

Chris

chrism0dwk · Aug 21 '25 08:08