[Bug] `botorch.optim.utils.sample_all_priors` fails to sample priors that GPyTorch can sample
botorch.optim.utils.sample_all_priors fails to sample from priors that GPyTorch can sample with module.sample_from_prior. The reason is that the parameter's shape is passed as the sample shape to the prior's sample method:
https://github.com/pytorch/botorch/blob/913aa0e510dde10568c2b4b911124cdd626f6905/botorch/optim/utils.py#L43
However, it seems that GPyTorch's priors do not expect to receive the event shape as a parameter.
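For context, `Distribution.sample(sample_shape)` in torch.distributions (which GPyTorch priors wrap) prepends `sample_shape` to the distribution's own batch/event shape. When the prior's parameters already carry the hyperparameter's shape, passing that shape again therefore doubles the dimensions. A minimal sketch of the mismatch, using torch.distributions directly:

import torch

# The prior's parameters already have the hyperparameter's shape [1, 2],
# mirroring the NormalPrior in the reproduction below.
prior = torch.distributions.Normal(torch.tensor([[1., 1.]]), torch.tensor(1.))

print(prior.sample().shape)                    # torch.Size([1, 2]) -- what GPyTorch expects
print(prior.sample(torch.Size([1, 2])).shape)  # torch.Size([1, 2, 1, 2]) -- what sample_all_priors passes along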
To reproduce
import torch
import gpytorch
import botorch

class GP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, kernel):
        super().__init__(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())
        self.kernel = kernel

train_x = torch.stack([torch.linspace(0, 1, 50), torch.linspace(-1, 0, 50)]).T
train_y = torch.sin(train_x * 3.14159 / 2).sum(-1)

gp = GP(train_x, train_y, gpytorch.kernels.RBFKernel(ard_num_dims=2))
gp.kernel.register_prior(
    'lengthscale_prior',
    gpytorch.priors.NormalPrior(torch.as_tensor([[1., 1.]]), torch.as_tensor(1.)),
    lambda module: torch.log(module.lengthscale),
    lambda module, sample: module._set_lengthscale(torch.exp(sample)),
)

gp.kernel.sample_from_prior('lengthscale_prior')  # Works
botorch.optim.utils.sample_all_priors(gp)  # Throws exception
Stack trace/error message
RuntimeError Traceback (most recent call last)
~\Programas\miniconda\envs\mcmc\lib\site-packages\gpytorch\module.py in initialize(self, **kwargs)
104 try:
--> 105 self.__getattr__(name).data.copy_(val.expand_as(self.__getattr__(name)))
106 except RuntimeError:
RuntimeError: expand(torch.FloatTensor{[1, 2, 1, 2]}, size=[1, 2]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (4)
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
<ipython-input-9-4b2855578412> in <module>
26 gp.kernel.sample_from_prior('lengthscale_prior') # Works
27
---> 28 botorch.optim.utils.sample_all_priors(gp) # Throws exception
~\Programas\miniconda\envs\mcmc\lib\site-packages\botorch\optim\utils.py in sample_all_priors(model)
41 )
42 try:
---> 43 setting_closure(module, prior.sample(closure(module).shape))
44 except NotImplementedError:
45 warnings.warn(
<ipython-input-9-4b2855578412> in <lambda>(module, sample)
21 gpytorch.priors.NormalPrior(torch.as_tensor([[1.,1.]]), torch.as_tensor(1.)),
22 lambda module: torch.log(module.lengthscale),
---> 23 lambda module, sample: module._set_lengthscale(torch.exp(sample)),
24 )
25
~\Programas\miniconda\envs\mcmc\lib\site-packages\gpytorch\kernels\kernel.py in _set_lengthscale(self, value)
252 value = torch.as_tensor(value).to(self.raw_lengthscale)
253
--> 254 self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))
255
256 def local_load_samples(self, samples_dict, memo, prefix):
~\Programas\miniconda\envs\mcmc\lib\site-packages\gpytorch\module.py in initialize(self, **kwargs)
108 self.__getattr__(name).data = val
109 else:
--> 110 self.__getattr__(name).data.copy_(val.view_as(self.__getattr__(name)))
111
112 elif isinstance(val, float):
RuntimeError: shape '[1, 2]' is invalid for input of size 4
Expected Behavior
No errors thrown.
System information
- BoTorch Version: 0.4.0
- GPyTorch Version: 1.4.0
- PyTorch Version: 1.8.1+cu101
- OS: Windows 10
So the sampling will work if you instantiate your prior with scalar parameters, e.g. gpytorch.priors.NormalPrior(1, 1) instead of what you're doing. Fundamentally, this is related to a bug (ambiguity?) in gpytorch, see https://github.com/cornellius-gp/gpytorch/issues/1317 and https://github.com/cornellius-gp/gpytorch/issues/1318.
I will leave this open for now, hopefully this unblocks you. To solve this properly we'll need to fix the root cause in gpytorch.
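A sketch of that workaround, reusing the reproduction above (re-registering under the same name replaces the earlier prior): with scalar parameters the prior's batch shape is empty, so prior.sample(closure(module).shape) returns exactly the parameter's shape.

gp.kernel.register_prior(
    'lengthscale_prior',
    gpytorch.priors.NormalPrior(1., 1.),  # scalar parameters: empty batch shape
    lambda module: torch.log(module.lengthscale),
    lambda module, sample: module._set_lengthscale(torch.exp(sample)),
)
botorch.optim.utils.sample_all_priors(gp)  # No exception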
I see. Even though my prior might be equivalent to gpytorch.priors.NormalPrior(1, 1) in this example, in the actual code I'm using, the loc parameter of the prior is different for each ARD dimension. It really does seem like the problem is the event shape/batch shape/sample shape ambiguity on GPyTorch's side.
In case someone else is also having this issue, and all of their priors have their event shape equal to the hyperparameter's shape, this workaround seems to work until this gets disambiguated in GPyTorch :)
import warnings

from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gpytorch import GPyTorchModel


def sample_all_priors(model: GPyTorchModel) -> None:
    r"""Sample from hyperparameter priors (in-place).

    Args:
        model: A GPyTorchModel.
    """
    for _, module, prior, closure, setting_closure in model.named_priors():
        if setting_closure is None:
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        try:
-            setting_closure(module, prior.sample(closure(module).shape))
+            setting_closure(module, prior.sample())
        except NotImplementedError:
            warnings.warn(
                f"`rsample` not implemented for {type(prior)}. Skipping.",
                BotorchWarning,
            )
Indeed - if you want priors with different (prior) parameters for different parameter dimensions, then this is the way to do it (for now, until we've found a proper solution on the gpytorch end).
It is worth noting that if you're using a 1-d prior for an ARD lengthscale, e.g. as in SingleTaskGP where the lengthscale prior is set to GammaPrior(3.0, 6.0), then setting_closure(module, prior.sample()) will sample only one value and set it for all dimensions.
Instead, you should define the prior as GammaPrior(torch.tensor([3., 3., 3.], dtype=torch.float), rate=torch.tensor([6., 6., 6.], dtype=torch.float)) (where the ARD dimension is 3 in this case), so that prior.sample() works as expected.
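For example, a sketch with the ARD dimension hardcoded to 3, using the standard lengthscale_prior argument of GPyTorch kernels:

import torch
import gpytorch

# Per-dimension Gamma prior whose batch shape [1, 3] matches the
# lengthscale shape of an ard_num_dims=3 kernel.
prior = gpytorch.priors.GammaPrior(
    torch.full((1, 3), 3.0),  # concentration, one entry per ARD dimension
    torch.full((1, 3), 6.0),  # rate
)
kernel = gpytorch.kernels.RBFKernel(ard_num_dims=3, lengthscale_prior=prior)

print(prior.sample().shape)      # torch.Size([1, 3])
print(kernel.lengthscale.shape)  # torch.Size([1, 3]) -- shapes agree, so prior.sample() can be set directly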
The example in the OP still produces the same error as of today.
I'm running into this same issue. It occurs if you use fit_gpytorch_mll on a model that uses different prior parameters for different ARD dimensions (e.g. if you set the lengthscale prior to a GammaPrior with two tensors as args, rather than two scalars). Because fit_gpytorch_mll only calls sample_all_priors occasionally, the issue is sporadic.
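A sketch of a setup that can hit this (the model and prior here are illustrative; the failure only surfaces on attempts where the fitting routine resamples from the priors via sample_all_priors):

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sin().sum(-1, keepdim=True)

# Tensor-valued prior parameters: one (concentration, rate) pair per ARD dimension.
covar_module = ScaleKernel(
    RBFKernel(
        ard_num_dims=2,
        lengthscale_prior=GammaPrior(
            torch.tensor([3.0, 3.0]), torch.tensor([6.0, 6.0])
        ),
    )
)
model = SingleTaskGP(train_X, train_Y, covar_module=covar_module)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)  # only errors on attempts that resample the priors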
Note that the workaround above won't behave as expected in batch mode. For example, when doing cross-validation, it will sample the same set of lengthscales for every cv fold, rather than different ones. Here's an altered workaround that works for this case.
def sample_all_priors(model: GPyTorchModel) -> None:
    r"""Sample from hyperparameter priors (in-place).

    Args:
        model: A GPyTorchModel.
    """
    for _, module, prior, closure, setting_closure in model.named_priors():
        if setting_closure is None:
            raise RuntimeError(
                "Must provide inverse transform to be able to sample from prior."
            )
        try:
-            setting_closure(module, prior.sample(closure(module).shape))
+            setting_closure(module, prior.sample(closure(module).shape[:-1]))
        except NotImplementedError:
            warnings.warn(
                f"`rsample` not implemented for {type(prior)}. Skipping.",
                BotorchWarning,
            )
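To illustrate the difference, a minimal shape check using torch.distributions.Gamma directly, assuming a per-dimension prior (batch shape [d]) and a batched lengthscale of shape [num_folds, 1, d], as produced by, e.g., LOO cross-validation:

import torch

num_folds, d = 5, 3
# Per-dimension prior: its batch shape [d] covers the trailing ARD dimension.
prior = torch.distributions.Gamma(torch.full((d,), 3.0), torch.full((d,), 6.0))
lengthscale_shape = torch.Size([num_folds, 1, d])  # one batched GP per CV fold

print(prior.sample().shape)                        # torch.Size([3]) -- same draw reused for every fold
print(prior.sample(lengthscale_shape[:-1]).shape)  # torch.Size([5, 1, 3]) -- independent draw per fold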