
Feature request: 2-D weights option for log prob

Open · ConstantinaNicolaou opened this issue 2 years ago · 1 comment

I am working on a VAE with a Gaussian decoder using TensorFlow Probability. I calculate the log likelihood using x_hat.log_prob(x), where x_hat is the output of the decoder (a tfp.distributions.Distribution). However, in my application it is important to apply weights to the values in the log_prob summation, using a weights tensor of floats whose shape matches x, i.e. weights.shape=(batch_size, features). This weights tensor should apply a different weight to each feature of each sample.

Essentially I want to do what is being done in the MSE example below:

MSE = tf.reduce_sum(0.5 * weights * (x - recon_x)**2)

where x, recon_x and weights all have shape (batch_size, features).

I tried looking for a way to get the log_prob before the summation, so that I could apply the weights and then sum, but I couldn't find anything relevant in the documentation.

Is there another way to go about doing this?

ConstantinaNicolaou avatar May 11 '22 13:05 ConstantinaNicolaou

You'd have to write your own wrapper distribution. tfd.Masked is similar in principle, but it only supports 0-1 weights. It could serve as inspiration, although really all you need is something like this:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util

tfd = tfp.distributions


class ScaleLogProb(tfd.Distribution):
  def __init__(self,
               distribution,
               weight,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    parameters = dict(locals())
    with tf.name_scope(name or f'ScaleLogProb{distribution.name}') as name:
      self._distribution = distribution
      self._weight = tensor_util.convert_nonref_to_tensor(
          weight, dtype_hint=tf.float32)
      super().__init__(
          dtype=distribution.dtype,
          reparameterization_type=distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)

  @property
  def distribution(self):
    return self._distribution

  @property
  def weight(self):
    return self._weight

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    return dict(
        distribution=parameter_properties.BatchedComponentProperties(),
        weight=parameter_properties.ParameterProperties(
            shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED))

  def _event_shape(self):
    return self.distribution.event_shape

  def _event_shape_tensor(self):
    return self.distribution.event_shape_tensor()

  def _log_prob(self, x):
    # Elementwise weights, broadcast against the base distribution's log_prob.
    return self.weight * self.distribution.log_prob(x)

And then you'd use it as:

d = tfd.Independent(ScaleLogProb(tfd.Normal(recon_x, 1.), weights), 1)
MSE = -d.log_prob(x)
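To see why this matches the weighted MSE above (up to a weight-dependent constant), here is a minimal NumPy sketch of the same arithmetic, assuming a unit-scale Normal decoder; the array names mirror the ones in the question:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, features = 4, 3
x = rng.normal(size=(batch_size, features))
recon_x = rng.normal(size=(batch_size, features))
weights = rng.uniform(size=(batch_size, features))

# Per-element log-density of Normal(recon_x, 1), as tfd.Normal(recon_x, 1.).log_prob(x)
# would compute it.
log_prob = -0.5 * (x - recon_x) ** 2 - 0.5 * np.log(2.0 * np.pi)

# Weight each element, then sum over features and batch -- the effect of
# wrapping in ScaleLogProb and tfd.Independent, negated and summed.
weighted_nll = -(weights * log_prob).sum()

# This equals the weighted MSE plus a constant that depends only on the weights.
weighted_mse = (0.5 * weights * (x - recon_x) ** 2).sum()
constant = 0.5 * np.log(2.0 * np.pi) * weights.sum()
assert np.allclose(weighted_nll, weighted_mse + constant)
```

So minimizing the weighted negative log likelihood and minimizing the weighted MSE give the same gradients with respect to recon_x; only the additive constant differs.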

For a proper TFP contribution, you'd have to probably disable sampling from such a distribution.

SiegeLordEx avatar May 11 '22 18:05 SiegeLordEx