Feature request: 2-D weights option for log prob
I am working on a VAE with a Gaussian decoder using TensorFlow Probability. I calculate the log likelihood using x_hat.log_prob(x), where x_hat is the output of the decoder (a tfp.distributions.Distribution). However, in my application it is important to apply weights to the values that go into the log_prob summation, using a weights tensor of floats whose shape matches x, i.e. weights.shape = (batch_size, features). This weights tensor applies a different weight to each feature of each sample.
Essentially I want to do what is being done in the MSE example below:
MSE = tf.reduce_sum(0.5 * weights * (x - recon_x)**2)
where x, recon_x and weights all have shape (batch_size, features).
I tried looking for a way to get the log_prob values before the summation, so that I could apply the weights and then sum myself, but I couldn't find anything relevant in the documentation.
Is there another way to go about doing this?
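To make the setup concrete, here is a minimal sketch (the sizes and the zero tensors are just placeholders; I'm assuming x_hat wraps an elementwise Normal in tfd.Independent):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_size, features = 32, 10                # placeholder sizes
x = tf.zeros([batch_size, features])         # placeholder data
recon_x = tf.zeros([batch_size, features])   # placeholder decoder output

x_hat = tfd.Independent(tfd.Normal(loc=recon_x, scale=1.),
                        reinterpreted_batch_ndims=1)

log_lik = x_hat.log_prob(x)  # shape (batch_size,): the per-feature terms are already
                             # summed, so there is no place to apply per-feature weights.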
You'd have to write your own wrapper distribution. tfd.Masked is similar in principle, but only supports 0-1 weights. It could be an inspiration, although really all you need to do is something like this:
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util

tfd = tfp.distributions


class ScaleLogProb(tfd.Distribution):
  """Wraps a distribution and scales its elementwise log_prob by a weight."""

  def __init__(self,
               distribution,
               weight,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    parameters = dict(locals())
    with tf.name_scope(name or f'ScaleLogProb{distribution.name}') as name:
      self._distribution = distribution
      self._weight = tensor_util.convert_nonref_to_tensor(
          weight, dtype_hint=tf.float32)
      super().__init__(
          dtype=distribution.dtype,
          reparameterization_type=distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)

  @property
  def distribution(self):
    return self._distribution

  @property
  def weight(self):
    return self._weight

  @classmethod
  def _parameter_properties(cls, dtype, num_classes=None):
    return dict(
        distribution=parameter_properties.BatchedComponentProperties(),
        weight=parameter_properties.ParameterProperties(
            shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED))

  def _event_shape(self):
    return self.distribution.event_shape

  def _event_shape_tensor(self):
    return self.distribution.event_shape_tensor()

  def _log_prob(self, x):
    # Elementwise: each log-density is multiplied by its corresponding weight.
    return self.weight * self.distribution.log_prob(x)
And then you'd use it as:
d = tfd.Independent(ScaleLogProb(tfd.Normal(recon_x, 1.), weights),
                    reinterpreted_batch_ndims=1)
MSE = -d.log_prob(x)  # weighted negative log-likelihood, shape (batch_size,)
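As a quick sanity check (using the x, recon_x and weights tensors from the question; this is just a sketch, not part of the original post): with unit scale the Normal log-density is -0.5 * (x - recon_x)**2 - 0.5 * log(2 * pi), so the weighted negative log-likelihood above should equal the weighted MSE plus a weighted constant:

import math

nll = -d.log_prob(x)                                              # shape (batch_size,)
wmse = tf.reduce_sum(0.5 * weights * (x - recon_x)**2, axis=-1)
const = tf.reduce_sum(weights, axis=-1) * 0.5 * math.log(2. * math.pi)
# nll should match wmse + const up to floating-point error.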
For a proper TFP contribution, you'd probably also have to disable sampling from such a distribution.
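For example (just a sketch, not the exact contribution requirement), you could override _sample_n on the class above so any attempt to sample raises with an explicit message:

  def _sample_n(self, n, seed=None):
    # Sampling a distribution whose density has been reweighted is not well
    # defined, so refuse it with a clear error.
    raise NotImplementedError(
        'ScaleLogProb only reweights log_prob; it does not define a sampler.')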