
[Bug] `botorch.optim.utils.sample_all_priors` fails to sample priors that GPyTorch can sample

Open spectraldani opened this issue 4 years ago • 7 comments

`botorch.optim.utils.sample_all_priors` fails to sample priors that GPyTorch can sample with `module.sample_from_prior`. The reason is that the parameter's shape is passed as the sample shape to the prior's `sample` method: https://github.com/pytorch/botorch/blob/913aa0e510dde10568c2b4b911124cdd626f6905/botorch/optim/utils.py#L43

However, it seems that GPyTorch's priors do not expect to receive the event shape as the sample shape.
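
Here is a minimal sketch of the shape mismatch, using the prior from the repro below: because the prior's parameters already carry the target shape as its batch shape, passing the parameter shape again as the sample shape doubles the dimensions.

import torch
from gpytorch.priors import NormalPrior

# The loc parameter has shape [1, 2], so the prior's batch shape is already [1, 2]
prior = NormalPrior(torch.as_tensor([[1., 1.]]), torch.as_tensor(1.))

print(prior.sample().shape)                    # torch.Size([1, 2]) -- matches the parameter
print(prior.sample(torch.Size([1, 2])).shape)  # torch.Size([1, 2, 1, 2]) -- dimensions doubled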

To reproduce

import torch
import gpytorch
import botorch

class GP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, kernel):
        super().__init__(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())
        self.kernel = kernel

train_x = torch.stack([torch.linspace(0,1,50), torch.linspace(-1,0,50)]).T
train_y = torch.sin(train_x*3.14159/2).sum(-1)

gp = GP(train_x, train_y, gpytorch.kernels.RBFKernel(ard_num_dims=2))

gp.kernel.register_prior(
    'lengthscale_prior',
    gpytorch.priors.NormalPrior(torch.as_tensor([[1.,1.]]), torch.as_tensor(1.)),
    lambda module: torch.log(module.lengthscale),
    lambda module, sample: module._set_lengthscale(torch.exp(sample)),
)

gp.kernel.sample_from_prior('lengthscale_prior') # Works

botorch.optim.utils.sample_all_priors(gp) # Throws exception

Stack trace/error message

RuntimeError                              Traceback (most recent call last)
~\Programas\miniconda\envs\mcmc\lib\site-packages\gpytorch\module.py in initialize(self, **kwargs)
    104                 try:
--> 105                     self.__getattr__(name).data.copy_(val.expand_as(self.__getattr__(name)))
    106                 except RuntimeError:

RuntimeError: expand(torch.FloatTensor{[1, 2, 1, 2]}, size=[1, 2]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (4)

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-4b2855578412> in <module>
     26 gp.kernel.sample_from_prior('lengthscale_prior') # Works
     27 
---> 28 botorch.optim.utils.sample_all_priors(gp) # Throws exception

~\Programas\miniconda\envs\mcmc\lib\site-packages\botorch\optim\utils.py in sample_all_priors(model)
     41             )
     42         try:
---> 43             setting_closure(module, prior.sample(closure(module).shape))
     44         except NotImplementedError:
     45             warnings.warn(

<ipython-input-9-4b2855578412> in <lambda>(module, sample)
     21     gpytorch.priors.NormalPrior(torch.as_tensor([[1.,1.]]), torch.as_tensor(1.)),
     22     lambda module: torch.log(module.lengthscale),
---> 23     lambda module, sample: module._set_lengthscale(torch.exp(sample)),
     24 )
     25 

~\Programas\miniconda\envs\mcmc\lib\site-packages\gpytorch\kernels\kernel.py in _set_lengthscale(self, value)
    252             value = torch.as_tensor(value).to(self.raw_lengthscale)
    253 
--> 254         self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))
    255 
    256     def local_load_samples(self, samples_dict, memo, prefix):

~\Programas\miniconda\envs\mcmc\lib\site-packages\gpytorch\module.py in initialize(self, **kwargs)
    108                         self.__getattr__(name).data = val
    109                     else:
--> 110                         self.__getattr__(name).data.copy_(val.view_as(self.__getattr__(name)))
    111 
    112             elif isinstance(val, float):

RuntimeError: shape '[1, 2]' is invalid for input of size 4

Expected Behavior

No errors thrown.

System information

  • BoTorch Version: 0.4.0
  • GPyTorch Version: 1.4.0
  • PyTorch Version: 1.8.1+cu101
  • OS: Windows 10

spectraldani avatar Apr 29 '21 18:04 spectraldani

So the sampling will work if you instantiate your prior with scalar parameters, e.g. gpytorch.priors.NormalPrior(1, 1) instead of what you're doing. Fundamentally, this is related to a bug (ambiguity?) in gpytorch, see https://github.com/cornellius-gp/gpytorch/issues/1317 and https://github.com/cornellius-gp/gpytorch/issues/1318.
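
Concretely, a small sketch of why scalar parameters work: with scalar arguments the prior's batch shape is empty, so passing the parameter shape as the sample shape produces a tensor of exactly that shape.

import torch
import gpytorch

# Scalar parameters: batch shape is torch.Size([]), so sample(shape)
# returns a tensor of exactly the requested shape
scalar_prior = gpytorch.priors.NormalPrior(1, 1)
print(scalar_prior.sample(torch.Size([1, 2])).shape)  # torch.Size([1, 2])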

I will leave this open for now, hopefully this unblocks you. To solve this properly we'll need to fix the root cause in gpytorch.

Balandat avatar Apr 30 '21 00:04 Balandat

I see. Even though the prior in my example happens to be equivalent to `gpytorch.priors.NormalPrior(1, 1)`, in my actual code the `loc` parameter of the prior is different for each ARD dimension. The problem really does seem to be due to the event shape/batch shape/sample shape ambiguity on GPyTorch's side.

In case someone else is also hitting this issue and all of their priors have an event shape equal to the hyperparameter's shape, this workaround seems to work until the ambiguity gets resolved in GPyTorch :)

def sample_all_priors(model: GPyTorchModel) -> None:
    r"""Sample from hyperparameter priors (in-place).
    Args:
        model: A GPyTorchModel.
    """
    for _, module, prior, closure, setting_closure in model.named_priors():
        if setting_closure is None:
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        try:
-            setting_closure(module, prior.sample(closure(module).shape))
+            setting_closure(module, prior.sample())
        except NotImplementedError:
            warnings.warn(
                f"`rsample` not implemented for {type(prior)}. Skipping.",
                BotorchWarning,
            )

spectraldani avatar Apr 30 '21 13:04 spectraldani

Indeed - if you want priors with different (prior) parameters for different parameter dimensions, then this is the way to do it (for now, until we've found a proper solution on the gpytorch end).

Balandat avatar Apr 30 '21 14:04 Balandat

It is worth noting that if you're using a 1d prior for an ARD lengthscale, e.g. as in `SingleTaskGP` where the lengthscale prior is set to `GammaPrior(3.0, 6.0)`, then `setting_closure(module, prior.sample())` will sample only one value and set it for all dimensions.

Instead, you should define the prior as `GammaPrior(concentration=torch.tensor([3., 3., 3.], dtype=torch.float), rate=torch.tensor([6., 6., 6.], dtype=torch.float))` (where the ARD dimension is 3 in this case), so that `prior.sample()` works as expected.
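
A small sketch contrasting the two (assuming an ARD dimension of 3, as above):

import torch
from gpytorch.priors import GammaPrior

# Scalar-parameter prior: sample() draws a single value, which then gets
# broadcast across all ARD dimensions when set on the lengthscale
scalar_prior = GammaPrior(3.0, 6.0)
print(scalar_prior.sample().shape)  # torch.Size([]) -- one draw for all dims

# Per-dimension prior: sample() draws one value per ARD dimension
ard_prior = GammaPrior(
    concentration=torch.tensor([3., 3., 3.]),
    rate=torch.tensor([6., 6., 6.]),
)
print(ard_prior.sample().shape)  # torch.Size([3])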

georgedeath avatar May 06 '21 09:05 georgedeath

The example in the OP still produces the error in the OP as of today.

esantorella avatar Jan 30 '23 20:01 esantorella

I'm running into this same issue. It occurs if you use `fit_gpytorch_mll` on a model whose priors have different parameters for different ARD dimensions (e.g. if you set the lengthscale prior to a `GammaPrior` with two tensors as arguments rather than two scalars). Because `fit_gpytorch_mll` only calls `sample_all_priors` occasionally, the failure is sporadic.

Note that the workaround above won't behave as expected in batch mode. For example, when doing cross-validation it will sample the same set of lengthscales for every CV fold rather than a different set per fold. Here's an altered workaround that handles this case.

def sample_all_priors(model: GPyTorchModel) -> None:
    r"""Sample from hyperparameter priors (in-place).
    Args:
        model: A GPyTorchModel.
    """
    for _, module, prior, closure, setting_closure in model.named_priors():
        if setting_closure is None:
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        try:
-            setting_closure(module, prior.sample(closure(module).shape))
+            setting_closure(module, prior.sample(closure(module).shape[:-1]))
        except NotImplementedError:
            warnings.warn(
                f"`rsample` not implemented for {type(prior)}. Skipping.",
                BotorchWarning,
            )
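
For illustration, a sketch of why dropping the last dimension gives per-fold draws, in a hypothetical batched setting (5 CV folds, ARD dimension 3, so the lengthscale has shape [5, 1, 3]):

import torch
from gpytorch.priors import GammaPrior

# Per-dimension prior with batch shape [3]
prior = GammaPrior(
    concentration=torch.full((3,), 3.0),
    rate=torch.full((3,), 6.0),
)

# The hypothetical batched lengthscale has shape [5, 1, 3]; dropping the
# last dim gives the sample shape, and the prior's batch shape fills it in
param_shape = torch.Size([5, 1, 3])
sample = prior.sample(param_shape[:-1])
print(sample.shape)  # torch.Size([5, 1, 3]) -- distinct draws per fold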

mrcslws avatar Mar 10 '23 16:03 mrcslws