[Bug] LKJCholeskyFactorPrior fails on GPU
🐛 Bug
LKJCholeskyFactorPrior fails on the GPU. This also breaks GPU support for botorch's KroneckerMultiTaskGP. cc @Balandat
To reproduce
**Code snippet to reproduce**
import torch
from botorch.models import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import LKJCholeskyFactorPrior

# Build a random correlation matrix and its Cholesky factor
# (LKJCholeskyFactorPrior is a prior over Cholesky factors of correlation matrices).
a = torch.randn(5, 5)
mat = a @ a.t() + torch.diag(torch.rand(5))
inv_sqrt = torch.diag(mat.diag().rsqrt())
corrmat = inv_sqrt @ mat @ inv_sqrt
chol = torch.linalg.cholesky(corrmat)

prior = LKJCholeskyFactorPrior(5, 0.5)
prior.log_prob(chol)  # works on the CPU
prior = prior.to(torch.device("cuda:0"))
prior.log_prob(chol.cuda())  # RuntimeError: tensors on cuda:0 and cpu

# The same failure surfaces through botorch's KroneckerMultiTaskGP.
train_x = torch.randn(30, 1).cuda()
train_y = torch.randn(30, 3).cuda()
model = KroneckerMultiTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
mll(model(*model.train_inputs), model.train_targets)  # same error
**Stack trace/error message**
File ~/gpytorch/gpytorch/priors/prior.py:27, in Prior.log_prob(self, x)
22 def log_prob(self, x):
23 r"""
24 :return: log-probability of the parameter value under the prior
25 :rtype: torch.Tensor
26 """
---> 27 return super(Prior, self).log_prob(self.transform(x))
File ~/miniconda3/lib/python3.9/site-packages/torch/distributions/lkj_cholesky.py:117, in LKJCholesky.log_prob(self, value)
115 order = torch.arange(2, self.dim + 1)
116 order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
--> 117 unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
118 # Compute normalization constant (page 1999 of [1])
119 dm1 = self.dim - 1
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Expected Behavior
LKJCholesky's log_prob should not error when the value lives on the GPU.
System information
Please complete the following information:
- pytorch 1.11
Additional context
This might already be fixed on the PyTorch nightly, but I thought I'd point it out:
https://github.com/pytorch/pytorch/issues/58774
I could also put up a PR here copying that log prob and just enforcing the order tensor to have the proper device.
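Roughly, that could look like the subclass sketched below (the class name is hypothetical, and the body is paraphrased from the LKJCholesky.log_prob shown in the stack trace rather than copied from an actual PR; the intended change is just keeping `order` and the concentration on the value's device):

```python
import math

import torch
from gpytorch.priors import LKJCholeskyFactorPrior


class DeviceSafeLKJCholeskyFactorPrior(LKJCholeskyFactorPrior):
    """Hypothetical variant whose log_prob keeps all tensors on value's device."""

    def log_prob(self, value):
        value = self.transform(value)  # mirrors gpytorch's Prior.log_prob
        # Keep every intermediate tensor on the device of the evaluation point.
        concentration = self.concentration.to(value.device)
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        order = torch.arange(2, self.dim + 1, device=value.device)  # <-- the fix
        order = 2 * (concentration - 1).unsqueeze(-1) + self.dim - order
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Normalization constant, as in the upstream implementation.
        dm1 = self.dim - 1
        alpha = concentration + 0.5 * dm1
        denominator = torch.lgamma(alpha) * dm1
        numerator = torch.mvlgamma(alpha - 0.5, dm1)
        pi_constant = 0.5 * dm1 * math.log(math.pi)
        normalize_term = pi_constant + numerator - denominator
        return unnormalized_log_pdf - normalize_term
```

Moving the concentration over as well seems necessary: judging by the stack trace (the failure is at the multiplication with diag_elems, not at the earlier line that adds the concentration), the concentration apparently stays on the CPU even after prior.to(cuda), so creating order on the GPU alone would just move the mismatch one line up.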
Hmm, if this also breaks on released PyTorch versions (e.g. 1.11), then we probably want to patch it in our code. But I'm actually not sure how you'd do this elegantly, since the issue happens inside LKJCholesky and doesn't seem easily circumventable (unless you do the computation on the CPU altogether or temporarily change the default torch device). Maybe @neerajprad has some thoughts?
The easiest fix is probably just to perform that computation on the CPU. There are probably few use cases where the inter-task covariance matrix is large enough for computing the log probability on the CPU to be a real slowdown.
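For instance, a minimal sketch of that CPU fallback (the helper name is made up, and it assumes the prior's own parameters live on the CPU, as in the stack trace above):

```python
def lkj_log_prob_on_cpu(prior, value):
    """Evaluate the LKJ prior's log-probability on a CPU copy of `value` and
    return the result on `value`'s original device."""
    return prior.log_prob(value.cpu()).to(value.device)


# With the reproduction above:
# log_p = lkj_log_prob_on_cpu(prior, chol.cuda())
```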
I also verified that this issue is somehow fixed on pytorch nightly.
This is a bug that was fixed by @dme65 last month, which is why the nightly seems to work fine. I'll also comment on the issue.
> The easiest fix is probably just to perform that computation on the CPU.
Unfortunately, I don't have a good interim solution. A hacky one that we have previously used in Pyro is to monkey-patch until the fix makes it into a release. See https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/torch_patch.py#L65 for one way to patch the log prob; the patch can be removed on PyTorch's next release.
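In that spirit, a rough sketch of what a temporary module-level patch could look like here (this is not the Pyro patch itself; it simply retries on the CPU when the device mismatch is hit, and would be deleted once the upstream fix ships in a release):

```python
from torch.distributions import LKJCholesky

# Keep a handle on the stock implementation so the patch can delegate to it.
_unpatched_log_prob = LKJCholesky.log_prob


def _patched_log_prob(self, value):
    try:
        return _unpatched_log_prob(self, value)
    except RuntimeError:
        # The upstream code builds `order` on the CPU; redo the whole
        # computation there and move the result back to value's device.
        return _unpatched_log_prob(self, value.cpu()).to(value.device)


# Monkey-patch: every caller (gpytorch priors, botorch models) now picks this up.
LKJCholesky.log_prob = _patched_log_prob
```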