Attempting to fix `MultivariateNormal.rsample`
When mean and covariance have different (but compatible) batch dimensions, the previous code fails.
There are two things that I don't know how we should handle:
- Is it important that `rsample` works over `base_samples` that can be broadcasted but don't have the same shape? I have decided to not add this feature because it wasn't there, I don't need it and it would require somewhat larger implementation changes. Arguably it's also out of scope of what `rsample` should be doing.
- What consequences does this have for `MultitaskMultivariateNormal`? (as per @Balandat's comment in #965).
- Yes, this is important. We'll routinely want to draw multiple samples from a mvn of given batch and event shape, so we pass in base samples with additional leading batch dimensions. It is crucial that this keeps working.
- `event_shape` for `MultitaskMultivariateNormal` has an explicit trailing output dimension (the number of tasks), while it does not for the standard `MultivariateNormal`. Not sure whether your changes will cause issues regarding that.
The thing to understand here is that while a regular mvn has an n-dim mean vector and an n x n covariance matrix, a MT mvn with t tasks has an n x t mean and an nt x nt covariance matrix.
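To make the shape difference concrete, here is a small shape-only sketch in plain Python. It does not call gpytorch; the helper names `mvn_shapes` and `multitask_mvn_shapes` are made up for illustration and only encode the dimensions described above.

```python
def mvn_shapes(n, batch_shape=()):
    """Shapes for a regular MVN over n points (hypothetical helper)."""
    return {
        "mean": batch_shape + (n,),
        "covar": batch_shape + (n, n),
        "event_shape": (n,),
    }


def multitask_mvn_shapes(n, t, batch_shape=()):
    """Shapes for a multitask MVN over n points and t tasks (hypothetical helper)."""
    return {
        "mean": batch_shape + (n, t),           # explicit trailing task dimension
        "covar": batch_shape + (n * t, n * t),  # covariance over all n * t outputs
        "event_shape": (n, t),
    }


print(mvn_shapes(3))               # mean (3,), covar (3, 3)
print(multitask_mvn_shapes(3, 2))  # mean (3, 2), covar (6, 6)
```

With n = 3 and t = 2 the multitask covariance is 6 x 6, which is the size one would expect for two independent 3-dim mvns combined into one multitask distribution.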
Hi!
> Yes, this is important. We'll routinely want to draw multiple samples from a mvn of given batch and event shape, so we pass in base samples with additional leading batch dimensions. It is crucial that this keeps working.
I didn't explain myself well. None of the current functionality is broken, things are only fixed. I'm asking whether it is necessary to add sampling when e.g. the mean is [3, 4] but the covariance is [4, 4] instead of [3, 4, 4]. Sampling from such a MultivariateNormal is already impossible at the moment. With this pull request, it is still possible to sample a MVN with mean.shape=[3, 4] and covariance_matrix.shape=[3, 4, 4].
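For reference, a rough sketch of how the batch shapes in these examples would broadcast, written in plain Python so it runs without torch. The functions `broadcast_shapes` and `mvn_batch_and_event` are illustrative helpers, not gpytorch API; they apply the usual right-aligned NumPy/PyTorch broadcasting rule to the batch dimensions of the mean and the covariance.

```python
from itertools import zip_longest


def broadcast_shapes(a, b):
    """Broadcast two shapes, aligning dimensions from the right."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def mvn_batch_and_event(mean_shape, covar_shape):
    """Return (batch_shape, event_shape) for a mean/covariance pair."""
    event_shape = tuple(mean_shape[-1:])
    if tuple(covar_shape[-2:]) != event_shape * 2:
        raise ValueError("covariance must be n x n for an n-dim mean")
    batch_shape = broadcast_shapes(tuple(mean_shape[:-1]), tuple(covar_shape[:-2]))
    return batch_shape, event_shape


# Mean [3, 4] with an unbatched [4, 4] covariance broadcasts to batch shape (3,).
print(mvn_batch_and_event((3, 4), (4, 4)))     # ((3,), (4,))
# Identical batch shapes give the same result.
print(mvn_batch_and_event((3, 4), (3, 4, 4)))  # ((3,), (4,))

# Samples drawn with sample_shape (5,) would then have shape (5, 3, 4).
batch, event = mvn_batch_and_event((3, 4), (4, 4))
print((5,) + batch + event)                    # (5, 3, 4)
```

Both shape combinations describe the same distribution shape; the only difference is whether the covariance is stored once or expanded across the batch.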
Re: Multitask MVN. Thank you very much for the explanation. I've taken a look, I don't think my code will cause issues with it. However, the issue I fixed here can be fixed there as well. I'll push an update soon.
I've written a more extensive test and gotten it working. But it adds many complicated moving parts to support a perhaps niche use case.
Perhaps we should just disallow all `MultivariateNormal`s from having different (but broadcastable) shapes for `loc` and `lazy_covariance_matrix`? (I think we should keep the test anyway.)
@Balandat I think you had asked for a chance to test some of these things downstream -- do you have any issues here? Just resurrecting this since we should either (1) keep the functionality as it is on master but raise an error if an MVN is created with nonidentical batch shapes, or (2) merge something like this to allow broadcasting between the two.
I am generally fine with either (1) or (2).
Sorry, I haven't been able to look into this more; let me try to run this over the weekend.
OK so the failure mode I'm running into on this is the following:
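```python
import torch
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal

# Two identical, independent 3-dim mvns combined into one 2-task mvn
mvn = MultivariateNormal(torch.rand(3), torch.eye(3))
mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
mtmvn.rsample(torch.Size([4]))
```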
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-043ef69b94bb> in <module>
      1 mvn = MultivariateNormal(torch.rand(3), torch.eye(3))
      2 mtmvn = MultitaskMultivariateNormal.from_independent_mvns([mvn, mvn])
----> 3 mtmvn.rsample(torch.Size([4]))

~/Code/gpytorch/gpytorch/distributions/multitask_multivariate_normal.py in rsample(self, sample_shape, base_samples)
    208             base_samples = base_samples.view(*sample_shape, *self.loc.shape)
    209
--> 210         samples = super().rsample(sample_shape=sample_shape, base_samples=base_samples)
    211         if not self._interleaved:
    212             # flip shape of last two dimensions

~/Code/gpytorch/gpytorch/distributions/multivariate_normal.py in rsample(self, sample_shape, base_samples)
    150             # Get samples
    151             samples = covar_.zero_mean_mvn_samples(num_samples)
--> 152             res = samples.view(*shape) + self.loc
    153             return res
    154         else:

RuntimeError: shape '[4, 6]' is invalid for input of size 48
```
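If you drop into this you see that `covar` has the correct shape [6, 6], but `covar_` has shape [12, 12]. Drawing 4 samples from it therefore yields 4 x 12 = 48 elements, which cannot be viewed as [4, 6], and that is why the sample reshaping fails.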
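The element count in the error message can be checked with a few lines of plain Python. This is only the arithmetic behind the failing `view`, using the sizes from the traceback; `can_view` is an illustrative helper, not a torch function.

```python
from math import prod


def can_view(numel, shape):
    """A view is only possible if the element counts match."""
    return numel == prod(shape)


sample_shape = (4,)
event_size = 6        # 3 points x 2 tasks, matches covar of shape [6, 6]
bad_covar_size = 12   # side length of covar_ as observed in the debugger

numel_drawn = prod(sample_shape) * bad_covar_size
target_shape = sample_shape + (event_size,)

print(numel_drawn)                          # 48
print(can_view(numel_drawn, target_shape))  # False: [4, 6] holds only 24 elements
```

If `covar_` had the expected size 6, the same check would pass with 4 x 6 = 24 elements.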