
TensorFlow Colab Code in Tutorial Doesn't Work

Open • naji-s opened this issue 3 years ago • 0 comments

While working through the tutorial "Bayesian Modeling with Joint Distribution" in its Colab notebook, running line 27, i.e.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# `_make_val_and_grad_fn`, `mapper`, `mdl_studentt` and `Y_np` are defined in
# earlier cells of the tutorial notebook.

@_make_val_and_grad_fn
def neg_log_likelihood(x):
  # Generate a function closure so that we are computing the log_prob
  # conditioned on the observed data. Note also that tfp.optimizer.* takes a
  # single tensor as input, so we need to do some slicing here:
  return -tf.squeeze(mdl_studentt.log_prob(
      mapper.split_and_reshape(x) + [Y_np]))

lbfgs_results = tfp.optimizer.lbfgs_minimize(
    neg_log_likelihood,
    initial_position=mapper.flatten_and_concat(mdl_studentt.sample()[:-1]),
    tolerance=1e-20,
    x_tolerance=1e-20
)
```

returns the following error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-38-7343a78d6c2d> in <module>()
     11     initial_position=mapper.flatten_and_concat(mdl_studentt.sample()[:-1]),
     12     tolerance=1e-20,
---> 13     x_tolerance=1e-20
     14 )

14 frames
/usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/distributions/joint_distribution.py in maybe_check_wont_broadcast(flat_xs, validate_args)
   1315   if all(ps.is_numpy(s_) for s_ in s):
   1316     if not all(same_shape(a, b) for a, b in zip(s[1:], s[:-1])):
-> 1317       raise ValueError(msg)
   1318     return flat_xs
   1319   assertions = [assert_util.assert_equal(a, b, message=msg)

ValueError: Broadcasting probably indicates an error in model specification.
```
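
For context, the frame shown in the traceback compares a tuple of static shapes `s` pairwise and raises as soon as two neighbouring shapes differ. Assuming these are the shapes of the per-component `log_prob` values (which is what the error message suggests), the check amounts to something like the plain-Python sketch below. The function name `check_wont_broadcast` is hypothetical, and this is a simplification for illustration, not TFP's actual code.

```python
def check_wont_broadcast(log_prob_shapes):
    """Simplified, hypothetical mimic of the static check in the traceback.

    Raises if the static log_prob shapes of the model's parts are not all equal.
    """
    shapes = [tuple(s) for s in log_prob_shapes]
    # Compare each shape with its predecessor, as in zip(s[1:], s[:-1]).
    if not all(a == b for a, b in zip(shapes[1:], shapes[:-1])):
        raise ValueError(
            "Broadcasting probably indicates an error in model specification.")
    return shapes


# b0, b1 and df (each a tfd.Sample(..., sample_shape=1)) have scalar log_prob
# shapes; a likelihood whose log_prob keeps a leftover batch axis, e.g. (1,),
# does not match them.
consistent = check_wont_broadcast([(), (), (), ()])
try:
    check_wont_broadcast([(), (), (), (1,)])
    raised = False
except ValueError:
    raised = True
```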


The problem can be solved by changing the model definition from

```python
def gen_studentt_model(X, sigma,
                       hyper_mean=0, hyper_scale=1, lower=1, upper=100):
  loc = tf.cast(hyper_mean, dtype)
  scale = tf.cast(hyper_scale, dtype)
  low = tf.cast(lower, dtype)
  high = tf.cast(upper, dtype)
  return tfd.JointDistributionSequential([
      # b0 ~ Normal(0, 1)
      tfd.Sample(tfd.Normal(loc, scale), sample_shape=1),
      # b1 ~ Normal(0, 1)
      tfd.Sample(tfd.Normal(loc, scale), sample_shape=1),
      # df ~ Uniform(a, b)
      tfd.Sample(tfd.Uniform(low, high), sample_shape=1),
      # likelihood ~ StudentT(df, f(b0, b1), sigma_y)
      #   Using Independent to ensure the log_prob is not incorrectly broadcasted.
      lambda df, b1, b0: tfd.Independent(
          tfd.StudentT(df=df, loc=b0 + b1*X, scale=sigma)),
  ], validate_args=True)
```

to

```python
def gen_studentt_model(X, sigma,
                       hyper_mean=0, hyper_scale=1, lower=1, upper=100):
  loc = tf.cast(hyper_mean, dtype)
  scale = tf.cast(hyper_scale, dtype)
  low = tf.cast(lower, dtype)
  high = tf.cast(upper, dtype)
  return tfd.JointDistributionSequential([
      # b0 ~ Normal(0, 1)
      tfd.Sample(tfd.Normal(loc, scale), sample_shape=1),
      # b1 ~ Normal(0, 1)
      tfd.Sample(tfd.Normal(loc, scale), sample_shape=1),
      # df ~ Uniform(a, b)
      tfd.Sample(tfd.Uniform(low, high), sample_shape=1),
      # likelihood ~ StudentT(df, f(b0, b1), sigma_y)
      #   Using Independent to ensure the log_prob is not incorrectly broadcasted.
      lambda df, b1, b0: tfd.Independent(
          tfd.StudentT(df=df, loc=b0 + b1*X, scale=sigma), reinterpreted_batch_ndims=1),
  ], validate_args=True)
```

i.e. by adding the argument `reinterpreted_batch_ndims=1` to `tfd.Independent`. I only figured this out by trial and error, paying particular attention to a warning saying that not passing `reinterpreted_batch_ndims` is deprecated. However, I don't understand why the argument is missing from the tutorial, so I would appreciate any clarification. One also needs to change

```python
mdl_studentt = gen_studentt_model(X_np[tf.newaxis, ...],
                                  sigma_y_np[tf.newaxis, ...])
```

to

```python
mdl_studentt = gen_studentt_model(X_np, sigma_y_np)
```

to make the broadcasting work, but again I don't follow the logic. Any help would be appreciated.
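
One way to follow the shape logic is to track static shapes by hand. The sketch below does this in plain Python; `broadcast_shape`, `independent_log_prob_shape` and `studentt_batch_shape` are hypothetical helpers, not TFP functions. It assumes `N = 20` observations (an arbitrary illustrative value), that `sigma_y_np` has the same shape as `X_np`, and that `b0`, `b1` and `df` each have shape `(1,)` because of `tfd.Sample(..., sample_shape=1)`. If I read the TFP documentation correctly, omitting `reinterpreted_batch_ndims` used to reinterpret all batch axes except the first one, and that implicit default is what the deprecation warning refers to; the `None` case below mimics it.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Broadcast static shapes (tuples) using NumPy-style rules."""
    result = []
    # Align shapes from the right; missing leading dimensions count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def independent_log_prob_shape(batch_shape, reinterpreted_batch_ndims=None):
    """Hypothetical helper: log_prob shape of Independent(dist), scalar events.

    `None` mimics the legacy default: reinterpret all batch axes but the first.
    """
    batch_shape = tuple(batch_shape)
    if reinterpreted_batch_ndims is None:
        reinterpreted_batch_ndims = max(len(batch_shape) - 1, 0)
    if not 0 <= reinterpreted_batch_ndims <= len(batch_shape):
        raise ValueError("reinterpreted_batch_ndims exceeds the batch rank")
    # The rightmost reinterpreted axes move into the event and are summed out.
    return batch_shape[:len(batch_shape) - reinterpreted_batch_ndims]


N = 20        # assumed number of observations (illustrative only)
PARAM = (1,)  # b0, b1 and df each come from tfd.Sample(..., sample_shape=1)


def studentt_batch_shape(x_shape):
    # loc = b0 + b1*X; sigma is assumed to have the same shape as X.
    loc = broadcast_shape(PARAM, broadcast_shape(PARAM, x_shape))
    # StudentT's batch shape is the broadcast of df, loc and scale.
    return broadcast_shape(PARAM, loc, x_shape)


plain = studentt_batch_shape((N,))       # X_np
newaxis = studentt_batch_shape((1, N))   # X_np[tf.newaxis, ...]

# The other parts have scalar log_prob (), so the likelihood must match it.
results = {
    "plain, default": independent_log_prob_shape(plain),
    "plain, ndims=1": independent_log_prob_shape(plain, 1),
    "newaxis, default": independent_log_prob_shape(newaxis),
    "newaxis, ndims=2": independent_log_prob_shape(newaxis, 2),
}
for name, shape in results.items():
    print(f"{name:18} log_prob shape {shape}")
```

Under these assumptions the legacy default leaves a likelihood `log_prob` of shape `(20,)` without `tf.newaxis` and `(1,)` with it, and neither matches the scalar `log_prob` of `b0`, `b1` and `df`. Passing `reinterpreted_batch_ndims=1` with plain `X_np`, or `reinterpreted_batch_ndims=2` with `X_np[tf.newaxis, ...]`, both reduce it to `()`, which suggests the number of reinterpreted axes simply has to equal the rank of `loc`.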

naji-s, Mar 11 '22 19:03