fix the shape bug in `LKJCovariancePrior`
This PR would fix #2685. There are two main changes in this PR.
- `SmoothedBoxPrior` has incorrect event shapes. As a result, the unit tests didn't catch the bug in #2685.
  - A univariate distribution should have an empty event shape. Thus, we should expect `SmoothedBoxPrior(a=1, b=2, sigma=0.1)` to have an event shape of `torch.Size([])`. However, univariate smoothed box priors currently have an event shape of `torch.Size([1])`. This PR fixes that (a quick check is sketched after this list).
- Unsqueeze the marginal standard deviations in `LKJCovariancePrior.sample` so that the matrix shapes work out for all priors, e.g., `GammaPrior`.
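To illustrate the first change, here is a minimal check of the event shape; the expected output assumes this PR's fix is applied.

```python
from gpytorch.priors import SmoothedBoxPrior

prior = SmoothedBoxPrior(a=1, b=2, sigma=0.1)
# A univariate prior should have an empty event shape.
print(prior.event_shape)  # expected after this PR: torch.Size([]); before: torch.Size([1])
```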
Two questions / concerns regarding this PR.
First, I am not sure if the changed event shape would affect downstream packages. It does seem to be the right thing to do, and all unit tests have passed. Nevertheless, it could be a breaking change.
Second, the diagonal entries of the covariance matrices generated by `LKJCovariancePrior` are always the same within each sample. They are identical even before this PR.
import torch
import gpytorch

sd_prior = gpytorch.priors.SmoothedBoxPrior(0.1, 2.0)
prior = gpytorch.priors.LKJCovariancePrior(n=2, eta=1.0, sd_prior=sd_prior)
samples = prior.sample(torch.Size([2]))
print(samples)
# output:
tensor([[[0.3535, 0.0609],
         [0.0609, 0.3535]],

        [[0.4324, 0.0467],
         [0.0467, 0.4324]]])
This is due to this line: we sample a single marginal standard deviation for each batch element, so within each batch the marginal standard deviations are the same across dimensions. https://github.com/cornellius-gp/gpytorch/blob/60be9533fb571482a4ea270921cb635d71da56f8/gpytorch/priors/lkj_prior.py#L111
I wonder if this is intentional. I think it would make more sense to sample from `self.sd_prior` independently for each diagonal entry; otherwise `LKJCovariancePrior` seems very restrictive. A standalone sketch of what I mean follows.
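This sketch uses plain torch tensors rather than gpytorch internals; `n`, `batch`, and the Gamma parameters are made up for illustration. It shows how drawing one marginal standard deviation per diagonal entry would make the diagonal entries differ within each batch element.

```python
import torch

n, batch = 2, 2
# Stand-in for correlation matrices sampled from the LKJ prior.
corr = torch.eye(n).expand(batch, n, n)
# One standard deviation per diagonal entry (heteroscedastic), instead of one per batch element.
sd = torch.distributions.Gamma(torch.tensor(2.0), torch.tensor(6.0)).sample((batch, n))
# Scale the correlation matrix into a covariance matrix: cov = diag(sd) @ corr @ diag(sd).
cov = sd.unsqueeze(-1) * corr * sd.unsqueeze(-2)
print(cov)  # diagonal entries now differ within each batch element
```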
This change makes sense to me. It's technically BC-breaking, but it also fixes a consistency issue that's kind of a pain and has been open for a long time (see my comments here).
> Second, the diagonal entries of the covariance matrices generated by `LKJCovariancePrior` are always the same
I don't think this is intentional; this seems like a bug. This bit in the docstring suggests that the prior is the same for each element on the diagonal, but that is of course the prior, not the realization when sampling. Let's fix it! https://github.com/cornellius-gp/gpytorch/blob/60be9533fb571482a4ea270921cb635d71da56f8/gpytorch/priors/lkj_prior.py#L77-L78
It seems that we need to support both homoscedastic and heteroscedastic marginal standard deviations. Both use cases are tested in `TestLKJCovariancePrior`. This is a bit tricky because torch distributions and `LKJCovariancePrior` have different interpretations of the batch shape and event shape.
Example 1
Let's say we want a batch of LKJ covariance priors on 3 x 3 covariance matrices and heteroscedastic Gamma priors on the diagonal entries, i.e., each diagonal entry in each batch gets a different prior. Currently this is not possible.
import torch
from gpytorch.priors import GammaPrior, LKJCovariancePrior, LKJPrior
# A batched LKJ prior over 3 x 3 correlation matrices
lkj_prior = LKJPrior(n=3, eta=torch.ones(2))
print(lkj_prior.batch_shape) # torch.Size([2])
print(lkj_prior.event_shape) # torch.Size([3, 3])
# Want heteroscedastic marginal standard deviations for each diagonal entry
sd_prior = GammaPrior(concentration=torch.rand(2, 3), rate=6.0)
print(sd_prior.batch_shape) # torch.Size([2, 3]), the batch shape is different from the LKJ prior above!
print(sd_prior.event_shape) # torch.Size([])
# Error. Cannot create the LKJ covariance prior because the batch shapes do not align
LKJCovariancePrior(
n=3,
eta=torch.ones(2),
sd_prior=sd_prior,
)
The Gamma prior always has an empty event shape. The shapes of its inputs only affect the batch shape, not the event shape. This makes sense from PyTorch's perspective, because Gamma distributions are univariate, but it causes headaches for us.
Ideally, we want the Gamma prior to have `batch_shape = (2,)` and `event_shape = (3,)` such that its batch shape is broadcastable with the batch shape of the LKJ prior and its event shape is broadcastable with the correlation matrices, as in the shape check below.
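A small sanity check of those shapes, sketched with plain tensors (the identity matrices merely stand in for LKJ correlation samples):

```python
import torch

sd = torch.rand(2, 3)                # desired layout: (*batch_shape, *event_shape) = (2, 3)
corr = torch.eye(3).expand(2, 3, 3)  # stand-in for samples from the batched LKJ prior
cov = sd.unsqueeze(-1) * corr * sd.unsqueeze(-2)  # diag(sd) @ corr @ diag(sd)
print(cov.shape)                     # torch.Size([2, 3, 3])
```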
Example 2
Again, we want a batch of LKJ covariance priors on 3 x 3 covariance matrices. This time the batch size is 3, which is coincidentally equal to the matrix dimension. The following code runs fine, but its semantics are ambiguous, admitting two different interpretations:
1. Three different homoscedastic marginal standard deviation priors;
2. A single heteroscedastic marginal standard deviation prior broadcast across the batch dimension of the LKJ prior.
import torch
from gpytorch.priors import GammaPrior, LKJCovariancePrior, LKJPrior
# A batched LKJ prior over 3 x 3 correlation matrices
lkj_prior = LKJPrior(n=3, eta=torch.ones(3))
print(lkj_prior.batch_shape) # torch.Size([3])
print(lkj_prior.event_shape) # torch.Size([3, 3])
# A batch of Gamma priors
sd_prior = GammaPrior(concentration=torch.rand(3), rate=6.0)
print(sd_prior.batch_shape) # torch.Size([3])
print(sd_prior.event_shape) # torch.Size([])
# This code works. But this use case is ambiguous and has two plausible interpretations.
LKJCovariancePrior(
n=3,
eta=torch.ones(3),
sd_prior=sd_prior,
)
Again, the root cause is that Gamma distributions always have empty event shapes, so it is up to the implementation to choose an interpretation. Perhaps interpretation 1 is more plausible? The two readings are contrasted in the sketch below.
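To make the two readings concrete, here is a small sketch with plain tensors (shapes only, no gpytorch involved) of how a (3,)-shaped sd sample could be expanded under each interpretation, for batch shape (3,) and n = 3:

```python
import torch

sd = torch.rand(3)  # one draw from a Gamma prior with batch_shape (3,)
# Interpretation 1: one sd per batch element, repeated across the diagonal (homoscedastic).
homo = sd.unsqueeze(-1).expand(3, 3)   # row i is [sd[i], sd[i], sd[i]]
# Interpretation 2: one sd per diagonal entry, shared across the batch (heteroscedastic).
hetero = sd.unsqueeze(0).expand(3, 3)  # every row is [sd[0], sd[1], sd[2]]
print(homo, hetero, sep="\n")
```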
Proposal 1
I am going to assume `GammaPrior.batch_shape` has at most one more dimension than `LKJCovariancePrior.batch_shape` and attempt to broadcast starting from the left. There are three potential cases:

1. `GammaPrior.batch_shape == LKJCovariancePrior.batch_shape`
2. `GammaPrior.batch_shape == (*LKJCovariancePrior.batch_shape, 1)`
3. `GammaPrior.batch_shape == (*LKJCovariancePrior.batch_shape, n)`

Cases 1 & 2 are homoscedastic and will invoke broadcasting in the last dimension. Case 3 is heteroscedastic. A rough sketch of this shape logic follows.
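Here is a rough sketch of that shape handling under the stated assumption; `reshape_marginal_sd` is a hypothetical helper for illustration, not code from this PR.

```python
import torch

def reshape_marginal_sd(sd_sample: torch.Tensor, lkj_batch_shape: torch.Size, n: int) -> torch.Tensor:
    """Return marginal sds of shape (*lkj_batch_shape, n), ready to scale an n x n correlation matrix."""
    if sd_sample.shape == lkj_batch_shape:          # case 1: homoscedastic
        return sd_sample.unsqueeze(-1).expand(*lkj_batch_shape, n)
    if sd_sample.shape == (*lkj_batch_shape, 1):    # case 2: homoscedastic, explicit trailing 1
        return sd_sample.expand(*lkj_batch_shape, n)
    if sd_sample.shape == (*lkj_batch_shape, n):    # case 3: heteroscedastic
        return sd_sample
    raise ValueError(f"Incompatible sd shape {tuple(sd_sample.shape)} for batch shape {tuple(lkj_batch_shape)}")
```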
Pro. This implementation should work for most priors, including all univariate priors like Gamma, Normal, and SmoothedBoxPrior.
Con.
- Broadcasting starting from the left goes against the usual (right-to-left) broadcasting semantics.
- Might cause trouble if the standard deviation prior has a non-empty event shape, e.g., `MultivariateNormalPrior`. The docstring says the standard deviation prior is a scalar prior, though, so this does not seem to be supported in the first place.
Proposal 2
Let users specify the event shape explicitly via `torch.distributions.Independent`. For example, the issue in Example 1 would be solved nicely by reinterpreting the last batch dimension as the event shape.
import torch
from gpytorch.priors import GammaPrior

sd_prior = GammaPrior(concentration=torch.rand(2, 3), rate=6.0)
print(sd_prior.batch_shape)  # torch.Size([2, 3])
print(sd_prior.event_shape)  # torch.Size([])
sd_prior = torch.distributions.Independent(sd_prior, reinterpreted_batch_ndims=1)
print(sd_prior.batch_shape)  # torch.Size([2])
print(sd_prior.event_shape)  # torch.Size([3])
Pro. This seems to be the cleanest approach (from the implementation perspective).
Con. The users have to do extra work.
I am leaning towards Proposal 1 in this PR, but I am also open to thoughts on Proposal 2.