[Bug] OrthogonalAdditiveKernel doesn't work with input transforms because they generate x values outside the unit hypercube
🐛 Bug
OrthogonalAdditiveKernel raises an error in `_check_hypercube` if given x values outside the unit hypercube [0, 1]^d. Unfortunately, when combining this kernel with basic BoTorch functionality, such values are hard to avoid. For example, if the search space is [0, 1], the model is trained on points ranging from 0.25 to 0.75, and a `Normalize` input transform is used, then 0 and 1 transform to -0.5 and 1.5 and lie outside the hypercube.
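For concreteness, a minimal sketch (not part of the repro below) of how `Normalize` produces out-of-cube values when it learns its bounds from training data that does not span the search space:

```python
import torch
from botorch.models.transforms.input import Normalize

tf = Normalize(d=1)
# In train mode, Normalize learns its bounds from the data it sees,
# here the training range [0.25, 0.75].
tf(torch.tensor([[0.25], [0.75]], dtype=torch.float64))
tf.eval()
# The edges of the search space now map outside [0, 1]^d:
print(tf(torch.tensor([[0.0], [1.0]], dtype=torch.float64)))
# tensor([[-0.5000], [ 1.5000]], dtype=torch.float64)
```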
To reproduce
** Code snippet to reproduce **
```python
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression import SingleTaskGP, get_matern_kernel_with_gamma_prior
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.tensor([[0.3], [0.7]], dtype=torch.float64)
train_Y = torch.tensor([[0.3], [0.7]], dtype=torch.float64)
kernel = OrthogonalAdditiveKernel(
    base_kernel=get_matern_kernel_with_gamma_prior(ard_num_dims=None),
    dim=1,
    dtype=torch.double,
)
model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
    covar_module=kernel,
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)
model.posterior(train_X)  # works
# errors, since Normalize maps 0.2 and 0.8 outside [0, 1]:
model.posterior(torch.tensor([[0.2], [0.8]], dtype=torch.float64))
```
** Stack trace/error message **
```
Traceback (most recent call last):
File "/Users/lizs/oak_issue.py", line 36, in <module>
model.posterior(torch.tensor([[0.2], [0.8]]), dtype=torch.float64)
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/gpytorch.py", line 388, in posterior
mvn = self(X)
^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 281, in exact_prediction
test_covar = joint_covar[..., self.num_train :, :].to_dense()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
return self.evaluate_kernel().to_dense()
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
output = method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
res = self.kernel(
^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/kernels/kernel.py", line 530, in __call__
super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
outputs = self.forward(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 163, in forward
K_ortho = self._orthogonal_base_kernels(x1, x2) # batch_shape x d x n1 x n2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 202, in _orthogonal_base_kernels
_check_hypercube(x1, "x1")
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 270, in _check_hypercube
raise ValueError(name + " is not in hypercube [0, 1]^d.")
ValueError: x1 is not in hypercube [0, 1]^d.
```
Constructing an acquisition function and optimizing it over the search space fails in the same way:

```python
acqf = qLogNoisyExpectedImprovement(
    model,
    X_baseline=train_X,
)
optimize_acqf(
    acqf,
    bounds=torch.tensor([[0.0], [1.0]], dtype=torch.float64),
    q=1,
    num_restarts=16,
    raw_samples=32,
)
```
```
Traceback (most recent call last):
File "/Users/lizs/oak_issue.py", line 43, in <module>
optimize_acqf(
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
return _optimize_acqf(opt_acqf_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
return _optimize_acqf_batch(opt_inputs=opt_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
batch_initial_conditions = opt_inputs.get_ic_generator()(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
Y_rnd_curr = acq_function(
^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/utils/transforms.py", line 305, in decorated
return method(cls, X, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/utils/transforms.py", line 259, in decorated
output = method(acqf, X, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/monte_carlo.py", line 274, in forward
non_reduced_acqval = self._non_reduced_forward(X=X)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/monte_carlo.py", line 287, in _non_reduced_forward
samples, obj = self._get_samples_and_objectives(X)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/logei.py", line 465, in _get_samples_and_objectives
posterior = self.model.posterior(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/gpytorch.py", line 388, in posterior
mvn = self(X)
^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 281, in exact_prediction
test_covar = joint_covar[..., self.num_train :, :].to_dense()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
return self.evaluate_kernel().to_dense()
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
output = method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
res = self.kernel(
^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/kernels/kernel.py", line 530, in __call__
super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
outputs = self.forward(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 163, in forward
K_ortho = self._orthogonal_base_kernels(x1, x2) # batch_shape x d x n1 x n2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 202, in _orthogonal_base_kernels
_check_hypercube(x1, "x1")
File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 270, in _check_hypercube
raise ValueError(name + " is not in hypercube [0, 1]^d.")
ValueError: x1 is not in hypercube [0, 1]^d.
```
Expected Behavior
This should not error. Currently, the only way to use the OAK is to manually normalize the search space (rather than the training data) to [0, 1]^d, which is neither documented nor well supported.
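A sketch of that workaround, under the assumption of a hypothetical raw search space [0, 2]: rescale inputs to the unit cube by hand, fit the model in the rescaled space with no `Normalize` transform, and map candidates back afterwards.

```python
import torch

# Hypothetical raw search space [0, 2]: rescale it to the unit cube by
# hand and fit the GP on train_X, with no Normalize input transform.
raw_bounds = torch.tensor([[0.0], [2.0]], dtype=torch.float64)
raw_X = torch.tensor([[0.6], [1.4]], dtype=torch.float64)
train_X = (raw_X - raw_bounds[0]) / (raw_bounds[1] - raw_bounds[0])
# Optimize the acquisition function over [0, 1] and map candidates back:
unit_bounds = torch.tensor([[0.0], [1.0]], dtype=torch.float64)
# candidate, _ = optimize_acqf(acqf, bounds=unit_bounds, q=1, ...)
# raw_candidate = candidate * (raw_bounds[1] - raw_bounds[0]) + raw_bounds[0]
```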
System information
Please complete the following information:
- BoTorch Version: 0.10.0
- GPyTorch Version: 1.11
- PyTorch Version: 2.2.2
- Computer OS: OS X
@SebastianAment is the requirement that all inputs are contained in the unit cube critical for this kernel?
Thanks for raising this. I added this check to ensure that the search space bounds are passed to `Normalize`; otherwise, the orthogonality condition can only be guaranteed on the training set. In the example above, passing `Normalize(d=1, bounds=bounds)` would work. We can add this to the error message.
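A minimal sketch of that suggested fix, using the search-space bounds from the repro above:

```python
import torch
from botorch.models.transforms.input import Normalize

# Pass the search-space bounds explicitly so Normalize maps the whole
# search space onto the unit cube, rather than learning bounds from the
# training data. The kernel's hypercube check then passes for any point
# in the search space.
bounds = torch.tensor([[0.0], [1.0]], dtype=torch.float64)
input_transform = Normalize(d=1, bounds=bounds)
```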
> @SebastianAment is the requirement that all inputs are contained in the unit cube critical for this kernel?
In principle we could also open the kernel up to be evaluated outside of the orthogonality domain, but I think it's better to error out in this case, at least by default, as orthogonality is the defining property that users would expect from the kernel.
cc @hvarfner
Closing because this is expected behavior. We can re-open whenever we want to re-evaluate the limitation.