
[Bug] LKJCholeskyFactorPrior fails on GPU

Open wjmaddox opened this issue 2 years ago • 5 comments

🐛 Bug

LKJCholeskyFactorPrior fails on the GPU. This also takes out GPU support for botorch's KroneckerMultiTaskGP @Balandat

To reproduce

**Code snippet to reproduce**

import torch
from gpytorch.priors import LKJCholeskyFactorPrior

# Build a random correlation matrix.
a = torch.randn(5, 5)
mat = a @ a.t() + torch.diag(torch.rand(5))
inv_sqrt = torch.diag(mat.diag().rsqrt())
corrmat = inv_sqrt @ mat @ inv_sqrt

prior = LKJCholeskyFactorPrior(5, 0.5)
prior.log_prob(corrmat)  # fine on the CPU

prior = prior.to(torch.device("cuda:0"))
prior.log_prob(corrmat.cuda())  # RuntimeError (see below)

## botorch error
from botorch.models import KroneckerMultiTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = torch.randn(30, 1).cuda()
train_y = torch.randn(30, 3).cuda()
model = KroneckerMultiTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

mll(model(*model.train_inputs), model.train_targets)  # same error

**Stack trace/error message**

File ~/gpytorch/gpytorch/priors/prior.py:27, in Prior.log_prob(self, x)
     22 def log_prob(self, x):
     23     r"""
     24     :return: log-probability of the parameter value under the prior
     25     :rtype: torch.Tensor
     26     """
---> 27     return super(Prior, self).log_prob(self.transform(x))

File ~/miniconda3/lib/python3.9/site-packages/torch/distributions/lkj_cholesky.py:117, in LKJCholesky.log_prob(self, value)
    115 order = torch.arange(2, self.dim + 1)
    116 order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
--> 117 unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
    118 # Compute normalization constant (page 1999 of [1])
    119 dm1 = self.dim - 1

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Expected Behavior

LKJCholeskyFactorPrior.log_prob should not error when the prior and its input are both on the GPU.

System information

Please complete the following information:

  • pytorch 1.11

Additional context

Might be fixed in pytorch nightly but thought I'd point it out:

https://github.com/pytorch/pytorch/issues/58774

I could also put up a PR here that copies that log prob and just forces the order tensor onto the proper device.
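
Roughly, that would mean overriding log_prob with a copy of the released implementation in which order (and the concentration) are created on the input's device. The sketch below is only an illustration of that idea (the subclass name is hypothetical, and the normalization term is transcribed from torch/distributions/lkj_cholesky.py as of 1.11); it is not the actual PR:

import math
import torch
from gpytorch.priors import LKJCholeskyFactorPrior

class PatchedLKJCholeskyFactorPrior(LKJCholeskyFactorPrior):
    # Hypothetical illustration: same math as LKJCholesky.log_prob, but `order`
    # and the concentration are placed on the device of the input value.
    def log_prob(self, value):
        value = self.transform(value)
        diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
        concentration = self.concentration.to(value)
        order = torch.arange(2, self.dim + 1, device=value.device)
        order = 2 * (concentration - 1).unsqueeze(-1) + self.dim - order
        unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
        # Normalization constant, as in the upstream implementation.
        dm1 = self.dim - 1
        alpha = concentration + 0.5 * dm1
        normalize_term = (
            0.5 * dm1 * math.log(math.pi)
            + torch.mvlgamma(alpha - 0.5, dm1)
            - torch.lgamma(alpha) * dm1
        )
        return unnormalized_log_pdf - normalize_term

# Usage: PatchedLKJCholeskyFactorPrior(5, 0.5).to("cuda:0").log_prob(corrmat.cuda())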

wjmaddox avatar Apr 19 '22 13:04 wjmaddox

Hmm, if this is also breaking on some PyTorch release versions (e.g. 1.11), then we probably want to patch it in our code.

But I am actually not sure how you'd do this elegantly, since the issue happens inside LKJCholesky and doesn't seem easily circumventable (unless you do the computation on the CPU altogether or temporarily change the default torch device). Maybe @neerajprad has some thoughts?

Balandat avatar Apr 19 '22 13:04 Balandat

Easiest fix is probably just to perform that computation on the CPU. There are probably few use cases where the inter-task covariance matrix is large enough for computing its log probability on the CPU to be a real slowdown.
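
A minimal sketch of that idea, assuming the prior's own parameters stay on the CPU (as in the traceback above) and using a made-up subclass name:

import torch
from gpytorch.priors import LKJCholeskyFactorPrior

class CPUFallbackLKJCholeskyFactorPrior(LKJCholeskyFactorPrior):
    # Evaluate the (small, num_tasks x num_tasks) log prob on the CPU and move
    # the result back to the device of the input.
    def log_prob(self, value):
        return super().log_prob(value.cpu()).to(value.device)

# Usage: CPUFallbackLKJCholeskyFactorPrior(5, 0.5).log_prob(corrmat.cuda())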

wjmaddox avatar Apr 19 '22 14:04 wjmaddox

I also verified that this issue is somehow fixed on pytorch nightly.

wjmaddox avatar Apr 19 '22 14:04 wjmaddox

This is a bug that was fixed by @dme65 last month, which is why the nightly works fine. I'll also comment on the issue.

neerajprad avatar Apr 19 '22 14:04 neerajprad

> Easiest fix is probably just to perform that computation on the CPU.

Unfortunately, I don't have a good solution in the interim. A hacky one that we have previously used in Pyro is to monkey-patch until the fix makes it into a release. See https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/torch_patch.py#L65 for one way to patch the log prob; the patch can be removed once the next PyTorch release is out.
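
For concreteness, a monkey-patch in that spirit could look roughly like the sketch below. This is hypothetical code (not the Pyro or PyTorch implementation); it detours the released LKJCholesky.log_prob through the CPU and can be deleted once the fix for pytorch/pytorch#58774 ships in a release:

import torch
from torch.distributions import LKJCholesky

_original_log_prob = LKJCholesky.log_prob

def _patched_log_prob(self, value):
    # Work around the CPU-only `order` tensor in the released implementation by
    # evaluating on the CPU and moving the result back to the caller's device.
    cpu_dist = LKJCholesky(self.dim, self.concentration.cpu(), validate_args=False)
    return _original_log_prob(cpu_dist, value.cpu()).to(value.device)

LKJCholesky.log_prob = _patched_log_prob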

neerajprad avatar Apr 19 '22 14:04 neerajprad