
What would be a good way to get sample mean of joint distribution?

Open xiaolongluo1979 opened this issue 4 years ago • 6 comments

It appears that the .mean() method is not implemented for custom joint distributions. We can estimate the mean by sampling, but the estimates do not appear stable. Can anyone suggest an effective method?

For example, we take the guide book distribution:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

joint = tfd.JointDistributionNamedAutoBatched(dict(
    e=tfd.Exponential(rate=[100, 120]),
    g=lambda e: tfd.Gamma(concentration=e[0], rate=e[1]),
    n=tfd.Normal(loc=0, scale=2.),
    m=lambda n, g: tfd.Normal(loc=n, scale=g),
    x=lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))

We would expect joint.mean() to give the marginal mean of each component, but according to the error message it is not implemented. We use the sampling method as follows:

# joint.mean() is not implemented

small_sample = joint.sample(2000)
large_sample = joint.sample(3000)

for k in small_sample.keys():
    print('Smaller sample mean', k, tf.reduce_mean(small_sample[k], axis=0).numpy())

print()

for k in large_sample.keys():
    print('Larger sample mean', k, tf.reduce_mean(large_sample[k], axis=0).numpy())

However, estimates from different sample sizes are quite different for sequential variables (e.g., n and m):

Smaller sample mean e [0.00983032 0.00844856]
Smaller sample mean g 9.612298
Smaller sample mean n -0.028275907
Smaller sample mean m 4.339757
Smaller sample mean x [0 0 0 0 0 0 0 0 0 0 0 0]

Larger sample mean e [0.00985759 0.00873891]
Larger sample mean g 12.751243
Larger sample mean n -0.0056907604
Larger sample mean m -14.086281
Larger sample mean x [0 0 0 0 0 0 0 0 0 0 0 0]

Is there a better method to get the marginal means?

Thanks, Xiaolong

xiaolongluo1979 avatar Dec 31 '21 23:12 xiaolongluo1979

You can use joint.sample_distributions() to return both samples and an instance of each component distribution (conditioned on the samples from the distributions on which it depends), and then can call .mean() on each component distribution.

Note that, to be able to call sample_distributions(sample_shape=...) with non-trivial sample_shape, you will need to use JointDistributionNamed instead of JointDistributionNamedAutoBatched.

For example:

joint = tfd.JointDistributionNamed(dict(
    e=tfd.Exponential(rate=[100, 120]),
    n=tfd.Normal(loc=0., scale=2.),
    g=lambda e: tfd.Gamma(concentration=e[..., 0], rate=e[..., 1]),
    m=lambda n, g: tfd.Normal(loc=n, scale=g),
    x=lambda m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))

num_samples = 100000

ds, xs = joint.sample_distributions(num_samples)

# ds is:
# {'e': <tfp.distributions.Exponential 'Exponential' batch_shape=[2] event_shape=[] dtype=float32>,
#  'g': <tfp.distributions.Gamma 'Gamma' batch_shape=[100000] event_shape=[] dtype=float32>,
#  'm': <tfp.distributions.Normal 'Normal' batch_shape=[100000] event_shape=[] dtype=float32>,
#  'n': <tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>,
#  'x': <tfp.distributions.Sample 'SampleBernoulli' batch_shape=[100000] event_shape=[12] dtype=int32>}

for k, d in ds.items():
  mean = d.mean()
  if mean.shape[:1] == [num_samples]:
    mean = tf.reduce_mean(mean, axis=0)
  print('{}: {}'.format(k, mean))

# example output:
#  e: [0.01       0.00833333]
#  g: 11.318492889404297
#  n: 0.0
#  m: 0.009445535950362682
#  x: [0.50191605 0.50191605 0.50191605 0.50191605 0.50191605 0.50191605
#   0.50191605 0.50191605 0.50191605 0.50191605 0.50191605 0.50191605]

jburnim avatar Jan 06 '22 01:01 jburnim

@jburnim Thanks for looking at this. Where did you use xs in the previous code? I thought the population marginal mean could be estimated by the component-wise mean of xs. How does d.mean() depend on the previous samples? I may be missing something.

I added a line that computes the sample mean from xs. The results appear different from d.mean():

for k, d in ds.items():
  mean = d.mean()
  if mean.shape[:1] == [num_samples]:
    mean = tf.reduce_mean(mean, axis=0)
  sam_mean=tf.reduce_mean(xs[k], axis=0)
  print('{}: {}'.format(k, mean))
  print('Joint samples mean {}: {}'.format(k, sam_mean))

output:

e: [0.01       0.00833333]
Joint samples mean e: [0.00993393 0.00833252]
g: 20.211729049682617
Joint samples mean g: 7.160811424255371
n: 0.0
Joint samples mean n: 0.0034047598019242287
m: 0.0034047598019242287
Joint samples mean m: 0.5829066634178162
x: [0.50019854 0.50019854 0.50019854 0.50019854 0.50019854 0.50019854
 0.50019854 0.50019854 0.50020075 0.50020075 0.50020075 0.50020075]
Joint samples mean x: [0 0 0 0 0 0 0 0 0 0 0 0]

xiaolongluo1979 avatar Jan 06 '22 21:01 xiaolongluo1979

In this case, joint.sample_distributions(num_samples) is roughly equivalent to:

distributions = {}
samples = {}

distributions['e'] = tfd.Exponential(rate=[100, 120])
samples['e'] = distributions['e'].sample(num_samples)

distributions['n'] = tfd.Normal(loc=0., scale=2.)
samples['n'] = distributions['n'].sample(num_samples)

distributions['g'] = tfd.Gamma(concentration=samples['e'][..., 0], rate=samples['e'][..., 1])
samples['g'] = distributions['g'].sample()

distributions['m'] = tfd.Normal(loc=samples['n'], scale=samples['g'])
samples['m'] = distributions['m'].sample()

distributions['x'] = tfd.Sample(tfd.Bernoulli(logits=samples['m']), 12)
samples['x'] = distributions['x'].sample()

return distributions, samples

In other words, joint.sample_distributions performs the same sampling procedure as joint.sample, but it returns all of the component distributions built during sample, in addition to returning the samples.
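
To make the difference between the two returned values concrete, here is a minimal sketch in plain Python (standard library only, not TFP) that mimics the ancestral sampling above for the m component. It is an illustration under stated assumptions, not how sample_distributions is implemented: the helper names ancestral_draw and estimate_mean_of_m are hypothetical, and the small floor on the exponential draws is added only so that random.gammavariate always receives positive arguments. For each draw it records both the conditional mean of m given its sampled parents (the analogue of ds['m'].mean(), which is simply the sampled n) and an actual draw of m (the analogue of xs['m']).

```python
import random


def ancestral_draw(rng):
    """Draw e, n, g, m in order; return (E[m | n, g], sampled m)."""
    # Floor is an assumption of this sketch: it guards against an exact 0.0 draw.
    e0 = max(rng.expovariate(100.0), 1e-12)  # concentration of g
    e1 = max(rng.expovariate(120.0), 1e-12)  # rate of g
    n = rng.gauss(0.0, 2.0)
    # random.gammavariate takes (shape, scale), so scale = 1 / rate.
    g = rng.gammavariate(e0, 1.0 / e1)
    m = rng.gauss(n, g)
    # For a Normal(loc=n, scale=g), the conditional mean is loc, i.e. n.
    return n, m


def estimate_mean_of_m(num_samples, seed=0):
    """Return (average of conditional means, average of sampled values) for m."""
    rng = random.Random(seed)
    cond_total = 0.0
    draw_total = 0.0
    for _ in range(num_samples):
        cond_mean, draw = ancestral_draw(rng)
        cond_total += cond_mean
        draw_total += draw
    return cond_total / num_samples, draw_total / num_samples


if __name__ == "__main__":
    cond, draw = estimate_mean_of_m(20000)
    print("average of conditional means of m:", cond)
    print("average of sampled m values:      ", draw)
```

The first returned value averages conditional means, so it carries no noise from the final Normal draw; the second includes that noise, which can be large whenever a sampled g is large.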

jburnim avatar Jan 09 '22 22:01 jburnim

Hi @jburnim, I think I know what you mean, but the code you included raised an error when I ran it. Maybe it was due to a typo?

AttributeError                            Traceback (most recent call last)
 in ()
     12
     13 distributions['m'] = tfd.Normal(loc=samples['n'], scale=samples['g']),
---> 14 samples['m'] = distributions['m'].sample(num_samples)
     15
     16 distributions['x'] = tfd.Sample(tfd.Bernoulli(logits=samples['m']), 12),

AttributeError: 'tuple' object has no attribute 'sample'

It seems to double the batches? Also, with the sequential use of d.sample(num_samples), the number of samples seems to grow exponentially? I am not sure it is feasible with a large num_samples.

xiaolongluo1979 avatar Jan 10 '22 14:01 xiaolongluo1979

Good catch -- there are two typos (extra commas at the end of lines distributions['m'] = ... and distributions['x'] = ...), and, to get the correct shapes, num_samples should only have been passed to a subset of the sample calls. I've edited the code snippet above to fix these issues.

jburnim avatar Jan 10 '22 17:01 jburnim

Hi @jburnim, thanks for fixing the typos. I am still not sure which way to compute the marginal mean of, say, the 'm' component.

We can take the mean of the conditional distribution means, distributions['m'].mean(), given e, n, g, or we can take the mean of the joint distribution samples. Both sound reasonable given the chain structure. However, the numbers do not appear close even for large num_samples. See the code here (please note that I modified the 'e' parameters so that they are not practically zero):

num_samples = 20000

distributions = {}
samples = {}

distributions['e'] = tfd.Exponential([10, 20])
samples['e'] = distributions['e'].sample(num_samples)

distributions['n'] = tfd.Normal(loc=0., scale=2.)
samples['n'] = distributions['n'].sample(num_samples)

distributions['g'] = tfd.Gamma(concentration=samples['e'][..., 0], rate=samples['e'][..., 1])
samples['g'] = distributions['g'].sample()

distributions['m'] = tfd.Normal(loc=samples['n'], scale=samples['g'])
samples['m'] = distributions['m'].sample()

distributions['x'] = tfd.Sample(tfd.Bernoulli(logits=samples['m']), 12)
samples['x'] = distributions['x'].sample()

print('Which should be used as the marginal mean of m distribution \n',
      tf.reduce_mean(distributions['m'].mean()).numpy(),
      ' or \n',
      tf.reduce_mean(distributions['m'].sample()).numpy())

return distributions, samples

Output:

Which should be used as the marginal mean of m distribution
 -0.0077947374  or
 -2.6458519

Let me know what you think. Thanks.

xiaolongluo1979 avatar Jan 10 '22 19:01 xiaolongluo1979
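
One possible reading of the two numbers in the last comment, offered as a sketch rather than as an answer from the maintainers. The symbols e_0 and e_1 (the two components of e, assumed independent) and \lambda (the rate of e_1) are notation introduced here. The first number averages conditional means of m, and since m is Normal(loc=n, scale=g) each conditional mean equals the sampled n, so that average is just the sample mean of n. The second number averages fresh draws of m, whose noise is scaled by g. Because the rate of g is itself exponentially distributed, the expectation of its reciprocal diverges for any \lambda > 0, so g has no finite mean and averages that involve draws scaled by g need not settle as the sample size grows. That would be consistent with the unstable g and m estimates reported in this thread.

```latex
\begin{align*}
\mathbb{E}[m \mid n, g] &= n, \\
\frac{1}{N}\sum_{i=1}^{N} \mathbb{E}[m \mid n_i, g_i]
  &= \frac{1}{N}\sum_{i=1}^{N} n_i
  \;\xrightarrow{\,N \to \infty\,}\; \mathbb{E}[n] = 0, \\
\mathbb{E}[g \mid e_0, e_1] &= \frac{e_0}{e_1}, \\
\mathbb{E}\!\left[\frac{1}{e_1}\right]
  &= \int_0^{\infty} \frac{1}{t}\,\lambda e^{-\lambda t}\,dt = \infty .
\end{align*}
```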