
Attempting to fix `MultivariateNormal.rsample`

Open rhaps0dy opened this issue 6 years ago • 6 comments

When mean and covariance have different (but compatible) batch dimensions, the previous code fails.

There are two things that I don't know how we should handle:

  1. Is it important that rsample works with base_samples that can be broadcast to the right shape but don't already have it? I have decided not to add this feature because it wasn't there before, I don't need it, and it would require somewhat larger implementation changes. Arguably it's also outside the scope of what rsample should be doing.
  2. What consequences does this have for MultitaskMultivariateNormal (as per @Balandat's comment in #965)?

rhaps0dy avatar Dec 02 '19 15:12 rhaps0dy

  1. Yes, this is important. We'll routinely want to draw multiple samples from a mvn of given batch and event shape, so we pass in base samples with additional leading batch dimensions. It is crucial that this keeps working.

  2. event_shape for MultitaskMultivariateNormal has an explicit trailing output dimension (the number of tasks), while it does not for the standard MultivariateNormal. Not sure whether your changes will cause issues regarding that.

The thing to understand here is that while a regular mvn has an n-dim mean vector and an n x n-dim covariance matrix, a MT mvn with t tasks will have an n x t-dim mean and an nt x nt-dim covariance matrix.
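The shape bookkeeping described above can be sketched in plain Python. This is only an illustration: the helper names (`mvn_shapes`, `multitask_mvn_shapes`, `sample_shape_from_base_samples`) are hypothetical and not part of gpytorch, and the sketch assumes base samples carry the sample dimensions first, followed by the full batch and event shape.

```python
def mvn_shapes(batch_shape, n):
    """Shapes for a regular MVN over n points (hypothetical helper)."""
    return {
        "mean": (*batch_shape, n),
        "covar": (*batch_shape, n, n),
        "event": (n,),
    }


def multitask_mvn_shapes(batch_shape, n, t):
    """Shapes for a multitask MVN over n points and t tasks (hypothetical helper)."""
    return {
        "mean": (*batch_shape, n, t),             # explicit trailing task dimension
        "covar": (*batch_shape, n * t, n * t),    # one joint covariance over all n*t outputs
        "event": (n, t),
    }


def sample_shape_from_base_samples(base_shape, batch_shape, event_shape):
    """Recover the leading sample dimensions of a base-sample shape.

    Assumes the base samples end with exactly batch_shape + event_shape.
    """
    tail = (*batch_shape, *event_shape)
    split = len(base_shape) - len(tail)
    if split < 0 or tuple(base_shape[split:]) != tail:
        raise ValueError(f"base samples {base_shape} do not end with {tail}")
    return tuple(base_shape[:split])


mt = multitask_mvn_shapes(batch_shape=(), n=3, t=2)
print(mt["mean"], mt["covar"])
print(sample_shape_from_base_samples((4, 3, 2), (), mt["event"]))
```

With n = 3 and t = 2 the mean has shape (3, 2) and the covariance (6, 6); base samples of shape (4, 3, 2) then correspond to a sample shape of (4,).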

Balandat avatar Dec 02 '19 16:12 Balandat

Hi!

> Yes, this is important. We'll routinely want to draw multiple samples from a mvn of given batch and event shape, so we pass in base samples with additional leading batch dimensions. It is crucial that this keeps working.

I didn't explain myself well. None of the current functionality is broken; things are only fixed. I'm asking whether it is necessary to add sampling when e.g. the mean is [3, 4] but the covariance is [4, 4] instead of [3, 4, 4]. Sampling from such a MultivariateNormal is already impossible at the moment. With this pull request, it is still possible to sample an MVN with mean.shape=[3, 4] and covariance_matrix.shape=[3, 4, 4].
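For reference, the broadcast case in question (mean of shape [3, 4], covariance of shape [4, 4]) can be sketched with plain PyTorch. This is not the gpytorch implementation and does not use lazy tensors; it is a minimal reparameterized sampler that assumes a dense, positive-definite covariance, compared against `torch.distributions.MultivariateNormal`, which broadcasts the two batch shapes on its own.

```python
import torch

torch.manual_seed(0)

mean = torch.rand(3, 4)   # batch shape [3], event shape [4]
covar = torch.eye(4)      # no batch dimension: shape [4, 4]

# Broadcast the batch shapes of the mean and the covariance.
batch_shape = torch.broadcast_shapes(mean.shape[:-1], covar.shape[:-2])

# Reparameterized sample: mean + L @ eps, where covar = L @ L^T.
sample_shape = torch.Size([5])
chol = torch.linalg.cholesky(covar)                              # [4, 4]
eps = torch.randn(*sample_shape, *batch_shape, mean.shape[-1])   # base samples: [5, 3, 4]
samples = mean + (chol @ eps.unsqueeze(-1)).squeeze(-1)          # broadcasts to [5, 3, 4]

# torch.distributions handles the same broadcast internally.
reference = torch.distributions.MultivariateNormal(mean, covariance_matrix=covar)
print(batch_shape, samples.shape, reference.rsample(sample_shape).shape)
```

Here the covariance is never expanded to [3, 4, 4]; matrix multiplication broadcasting does the work, which is one way to support the mixed-shape case without materializing copies.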

Re: Multitask MVN. Thank you very much for the explanation. I've taken a look, I don't think my code will cause issues with it. However, the issue I fixed here can be fixed there as well. I'll push an update soon.

rhaps0dy avatar Dec 18 '19 22:12 rhaps0dy

I've written a more extensive test and gotten it working. But it's many complicated moving parts for supporting a perhaps niche use case.

Perhaps we should just disallow all MultivariateNormals from having different (but broadcastable) shapes for loc and lazy_covariance_matrix? (We should keep the test anyway, I think.)

rhaps0dy avatar Dec 19 '19 02:12 rhaps0dy

@Balandat You had asked I think for a chance to test some of these things downstream -- do you have any issues here? Just resurrecting this since either we should (1) keep the functionality as it is on master but raise an error if an MVN is created with nonidentical batch shapes, or (2) merge something like this to allow broadcasting between the two.

I am generally fine with either (1) or (2).
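The two options can be contrasted with a small shape check in plain Python. The function names and the `allow_broadcast` flag are hypothetical, used only to illustrate the choice; neither is an existing gpytorch API.

```python
from itertools import zip_longest


def broadcast_batch_shapes(a, b):
    """Broadcast two batch shapes using the usual right-aligned rule."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def resolve_batch_shape(loc_shape, covar_shape, allow_broadcast):
    """Return the batch shape of an MVN with the given loc and covariance shapes.

    allow_broadcast=False corresponds to option (1): raise on nonidentical batch shapes.
    allow_broadcast=True corresponds to option (2): broadcast the two batch shapes.
    """
    loc_batch = tuple(loc_shape[:-1])       # drop the event dimension
    covar_batch = tuple(covar_shape[:-2])   # drop the two matrix dimensions
    if loc_batch == covar_batch:
        return loc_batch
    if not allow_broadcast:
        raise ValueError(
            f"loc batch shape {loc_batch} != covariance batch shape {covar_batch}"
        )
    return broadcast_batch_shapes(loc_batch, covar_batch)


print(resolve_batch_shape((3, 4), (3, 4, 4), allow_broadcast=False))
print(resolve_batch_shape((3, 4), (4, 4), allow_broadcast=True))
```

Under option (1) the second call would raise at construction time instead of failing later inside rsample.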

jacobrgardner avatar Feb 20 '20 16:02 jacobrgardner

Sorry haven’t been able to look into this more; let me try to run this over the weekend.

Balandat avatar Feb 20 '20 19:02 Balandat

OK so the failure mode I'm running into on this is the following:

import torch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal

mvn = MultivariateNormal(torch.rand(3), torch.eye(3))
mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
mtmvn.rsample(torch.Size([4]))

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-043ef69b94bb> in <module>
      1 mvn = MultivariateNormal(torch.rand(3), torch.eye(3))
      2 mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
----> 3 mtmvn.rsample(torch.Size([4]))

~/Code/gpytorch/gpytorch/distributions/multitask_multivariate_normal.py in rsample(self, sample_shape, base_samples)
    208             base_samples = base_samples.view(*sample_shape, *self.loc.shape)
    209 
--> 210         samples = super().rsample(sample_shape=sample_shape, base_samples=base_samples)
    211         if not self._interleaved:
    212             # flip shape of last two dimensions

~/Code/gpytorch/gpytorch/distributions/multivariate_normal.py in rsample(self, sample_shape, base_samples)
    150             # Get samples
    151             samples = covar_.zero_mean_mvn_samples(num_samples)
--> 152             res = samples.view(*shape) + self.loc
    153             return res
    154         else:

RuntimeError: shape '[4, 6]' is invalid for input of size 48

If you step into this call you see that `covar` has the correct shape [6, 6], but `covar_` has shape [12, 12], which is why the sample reshaping fails.
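The numbers in the error message follow from element counts alone. The sketch below is plain Python with hypothetical names; it takes the shapes reported above as given and does not explain why `covar_` ends up as [12, 12].

```python
from math import prod

num_samples = 4
n, t = 3, 2                                # 3 points, 2 tasks

target_shape = (num_samples, n * t)        # the [4, 6] that samples.view(*shape) asks for
expected_covar_dim = n * t                 # covar is [6, 6]
observed_covar_dim = 12                    # covar_ is [12, 12], as reported above

# Drawing num_samples zero-mean samples yields num_samples * covar_dim elements.
expected_elements = num_samples * expected_covar_dim
observed_elements = num_samples * observed_covar_dim

# A view is only valid when the element counts match.
print(prod(target_shape), expected_elements, observed_elements)
print(observed_elements == prod(target_shape))
```

A [12, 12] covariance produces 4 x 12 = 48 elements, while the [4, 6] view needs 24, which matches "shape '[4, 6]' is invalid for input of size 48".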

Balandat avatar Feb 21 '20 05:02 Balandat