
[Bug] `sample_all_priors` doesn't work with `KroneckerMultiTaskGP`

Status: Open. esantorella opened this issue 2 years ago • 5 comments

🐛 Bug

Moved here from #1323.

To reproduce

**Code snippet to reproduce**

import torch
from botorch import fit_gpytorch_mll
from botorch.models.multitask import KroneckerMultiTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.optim.utils import sample_all_priors
from botorch.utils.transforms import normalize
from gpytorch.mlls import ExactMarginalLogLikelihood

tkwargs = {
    "dtype": torch.double,
    "device": "cpu",
}

# A single training point with 3 inputs and 2 outputs (tasks).
train_x = torch.rand(1, 3, **tkwargs)
train_obj = torch.rand(1, 2, **tkwargs)

# Normalize inputs to the unit cube (a no-op here, since the bounds are [0, 1]).
train_x = normalize(
    train_x, torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], **tkwargs)
)

model = KroneckerMultiTaskGP(train_x, train_obj, outcome_transform=Standardize(m=2))
mll = ExactMarginalLogLikelihood(model.likelihood, model)

fit_gpytorch_mll(mll)
# Raises RuntimeError (see the stack trace below).
sample_all_priors(mll.model)

**Stack trace/error message**

  File "/Users/santorella/issue_repros/botorch_1323.py", line 26, in <module>
    sample_all_priors(mll.model)
  File "/Users/santorella/repos/botorch/botorch/optim/utils/model_utils.py", line 195, in sample_all_priors
    raise RuntimeError(
RuntimeError: Must provide inverse transform to be able to sample from prior.

Expected Behavior

This should work, because `fit_gpytorch_mll` falls back to calling `sample_all_priors` when a model-fitting attempt fails.

System information


  • BoTorch Version: 0.8.6.dev9+g0ad879cc
  • GPyTorch Version: 1.11
  • PyTorch Version: 1.13.0
  • OS: macOS

Additional context

See comments on #1323 for more info.

From @saitcakmak: "Looks like both `MultitaskGaussianLikelihood` and `IndexKernel` are missing a `setting_closure`."
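To see why a missing `setting_closure` produces the error above, here is a simplified, dependency-free sketch of the check that `sample_all_priors` performs. It is not BoTorch's code: it assumes each prior entry is a `(name, module, prior, closure, setting_closure)` tuple, as yielded by GPyTorch's `named_priors()`, and it omits the retry loop and sample-shape handling. `sample_all_priors_sketch`, `ConstantPrior`, and `Params` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Iterable, Optional, Tuple

# Hypothetical alias for one entry of `named_priors()`:
# (name, module, prior, closure, setting_closure).
PriorEntry = Tuple[str, Any, Any, Callable, Optional[Callable]]


def sample_all_priors_sketch(named_priors: Iterable[PriorEntry]) -> None:
    """Simplified sketch: no retries and no sample-shape handling."""
    for _name, module, prior, _closure, setting_closure in named_priors:
        if setting_closure is None:
            # Without a setting closure there is no way to write the sampled
            # value back onto the module's parameters.
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        setting_closure(module, prior.sample())


class ConstantPrior:
    """Hypothetical stand-in for a prior that always 'samples' one value."""

    def __init__(self, value: float) -> None:
        self.value = value

    def sample(self) -> float:
        return self.value


class Params:
    """Hypothetical stand-in for a module with one hyperparameter."""

    def __init__(self) -> None:
        self.lengthscale = 1.0


# A prior registered with a setting closure can be sampled and written back.
params = Params()
sample_all_priors_sketch([
    ("lengthscale_prior", params, ConstantPrior(0.5),
     lambda m: m.lengthscale,
     lambda m, v: setattr(m, "lengthscale", v)),
])
print(params.lengthscale)  # 0.5
```

An entry whose `setting_closure` is `None`, as with a prior registered through only a getter closure, makes the sketch raise the same `RuntimeError` shown in the stack trace.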

From @Balandat:

This is most likely the LKJCovariancePrior over the intra-task correlation matrix, defined by default here:

https://github.com/pytorch/botorch/blob/fd30429726e2067e097a0e57123cad335782de46/botorch/models/multitask.py#L476

Tracing this down, it is registered here: https://github.com/cornellius-gp/gpytorch/blob/d171863c50ab16b5bfb7035e579dcbe53169e703/gpytorch/kernels/index_kernel.py#L71

Basically, this would need a `setting_closure`. In this case we're passing in a covariance matrix Sigma, so the closure would have to take in Sigma, factor it into a correlation matrix C and the variances var, perform a root decomposition of C, and then set the `covar_factor` and `var` attributes of the `IndexKernel`.

esantorella, Jun 05 '23 19:06
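The closure recipe quoted above can be sketched in plain Python. This is an illustration under stated assumptions, not a merged fix: `sigma_to_index_kernel_params`, `cholesky`, and `reconstruct` are hypothetical helpers, the kernel's rank is assumed to equal the number of tasks, and only a single (non-batched) Sigma is handled. Because `IndexKernel` models the task covariance as `covar_factor @ covar_factor.T + diag(var)`, setting `covar_factor` to a root of C and `var` to the raw variances would not sum back to Sigma. The sketch instead reserves a small fraction of each variance for `var` and rescales the root of the correspondingly shifted correlation matrix by the standard deviations, so the two terms reproduce Sigma exactly.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def cholesky(a: Matrix) -> Matrix:
    """Lower-triangular L with L @ L^T == a (Cholesky-Banachiewicz)."""
    n = len(a)
    lower = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(lower[i][k] * lower[j][k] for k in range(j))
            if i == j:
                d = a[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                lower[i][j] = math.sqrt(d)
            else:
                lower[i][j] = (a[i][j] - s) / lower[j][j]
    return lower


def sigma_to_index_kernel_params(
    sigma: Matrix, var_fraction: float = 1e-4
) -> Tuple[Matrix, List[float]]:
    """Hypothetical helper: return (covar_factor, var) such that
    covar_factor @ covar_factor^T + diag(var) == sigma."""
    n = len(sigma)
    # 1. Split Sigma into per-task variances and a correlation matrix C.
    variances = [sigma[i][i] for i in range(n)]
    sds = [math.sqrt(v) for v in variances]
    corr = [[sigma[i][j] / (sds[i] * sds[j]) for j in range(n)] for i in range(n)]
    # 2. Keep a small fraction of each variance for the (positive) `var` term
    #    and root-decompose C - var_fraction * I. This requires var_fraction
    #    to be below the smallest eigenvalue of C, otherwise cholesky raises.
    shifted = [
        [corr[i][j] - (var_fraction if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]
    root = cholesky(shifted)
    # 3. Rescale the root by the standard deviations: F = D^{1/2} @ root.
    covar_factor = [[sds[i] * root[i][j] for j in range(n)] for i in range(n)]
    var = [var_fraction * v for v in variances]
    return covar_factor, var


def reconstruct(covar_factor: Matrix, var: List[float]) -> Matrix:
    """Compute covar_factor @ covar_factor^T + diag(var)."""
    n, rank = len(covar_factor), len(covar_factor[0])
    return [
        [
            sum(covar_factor[i][k] * covar_factor[j][k] for k in range(rank))
            + (var[i] if i == j else 0.0)
            for j in range(n)
        ]
        for i in range(n)
    ]


sigma = [[4.0, 1.2], [1.2, 1.0]]
covar_factor, var = sigma_to_index_kernel_params(sigma)
print(reconstruct(covar_factor, var))  # approximately [[4.0, 1.2], [1.2, 1.0]]
```

In an actual setting closure the same steps would be done with torch (for example `torch.linalg.cholesky`), batched over any leading sample dimensions, and the results written to the kernel's `covar_factor` parameter and `var`; check how `IndexKernel` constrains `var` before relying on this.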

Hi! I'm a newcomer; could I please work on a PR for this? @esantorella

vmallela0, Jun 14 '23 23:06


Of course!

esantorella, Jun 18 '23 21:06

@esantorella, I am a bit confused about where this `setting_closure` would go. Would it be in the GPyTorch `IndexKernel` or in the BoTorch multitask file?

vmallela0, Jul 16 '23 15:07

@vmallela0, the setting closure would need to be passed to `register_prior` when the prior is "registered" on the GPyTorch `Module` (see the link in the comment above: https://github.com/cornellius-gp/gpytorch/blob/d171863c50ab16b5bfb7035e579dcbe53169e703/gpytorch/kernels/index_kernel.py#L71). Here is where `register_prior` is defined: https://github.com/cornellius-gp/gpytorch/blob/d171863c50ab16b5bfb7035e579dcbe53169e703/gpytorch/module.py#L202

For an example, see how the base `Kernel` uses `register_prior` to register a lengthscale prior: https://github.com/cornellius-gp/gpytorch/blob/d171863c50ab16b5bfb7035e579dcbe53169e703/gpytorch/kernels/kernel.py#L174-L176

sdaulton, Jul 17 '23 14:07

Does anyone have a code snippet showing how the issue with the `LKJCovariancePrior` can be fixed?

EDIT: An example of a setting closure for the `LKJCovariancePrior` on `IndexKernel` is here: https://github.com/pytorch/botorch/discussions/2656#discussioncomment-11785034

Hrovatin, Dec 20 '24 15:12