[Bug] `batch_cross_validation` should not always pass `train_Yvar` to the model
🐛 Bug
`batch_cross_validation` passes `train_Yvar=None` into model classes that don't use it (e.g. `SingleTaskGP`).
Instead, it should only pass `train_Yvar` if it is not None.
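For concreteness, a minimal sketch of the proposed behavior (the variable names below are illustrative, not botorch's actual internals; they only assume the `CVFolds` container exposes `train_X`, `train_Y`, and `train_Yvar`):

```python
# Only forward train_Yvar to the model constructor when the CV folds
# actually carry noise observations; otherwise leave it out entirely so
# models like SingleTaskGP never see an unexpected keyword argument.
model_kwargs = {"train_X": cv_folds.train_X, "train_Y": cv_folds.train_Y}
if cv_folds.train_Yvar is not None:
    model_kwargs["train_Yvar"] = cv_folds.train_Yvar
model_cv = model_cls(**model_kwargs)
```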
Additional context
Originally flagged in #1667
The context for this issue is that the following code confusingly raises a warning through no fault of the user:
```python
from functools import partial

import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 1, dtype=torch.float64) * 2
train_Y = (torch.sin(6 * train_X) + 0.2 * torch.rand_like(train_X)).squeeze()
cv_folds = gen_loo_cv_folds(train_X=train_X, train_Y=train_Y)

batch_cross_validation(
    partial(
        SingleTaskGP,
        outcome_transform=Standardize(m=1, batch_shape=(20,)),
        input_transform=Normalize(d=1),
    ),
    ExactMarginalLogLikelihood,
    cv_folds,
)
```
Running this emits:

```
UserWarning: Keyword arguments ['train_Yvar'] will be ignored because they are not allowed parameters. Allowed parameters are ['train_X', 'train_Y', 'likelihood', 'covar_module', 'mean_module', 'outcome_transform', 'input_transform'].
```
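Until the fix lands, one possible user-side workaround (continuing from the snippet above; this is only a sketch of how one might silence this specific message with the standard library, not part of the proposed fix):

```python
import warnings

# Workaround sketch: suppress just this UserWarning until
# batch_cross_validation stops forwarding train_Yvar=None to models
# that do not accept it. The regex matches the start of the message.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message=r"Keyword arguments \['train_Yvar'\] will be ignored",
        category=UserWarning,
    )
    batch_cross_validation(
        partial(
            SingleTaskGP,
            outcome_transform=Standardize(m=1, batch_shape=(20,)),
            input_transform=Normalize(d=1),
        ),
        ExactMarginalLogLikelihood,
        cv_folds,
    )
```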
Other improvements to `batch_cross_validation` would be:
- Allow `batch_cross_validation` to pass transforms or other keyword arguments to the model initializer, so that no one needs to do anything weird with `partial` (see the sketch after this list)
- Make cross-validation play better with transforms -- writing the batch shape correctly is pretty non-intuitive
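As a rough illustration of the first suggestion, here is a hypothetical call with a `model_init_kwargs` argument; that parameter name is an assumption made for this sketch, not an existing part of the API described in this issue:

```python
# Hypothetical interface sketch: if batch_cross_validation accepted a
# model_init_kwargs dict, the transforms could be forwarded to
# SingleTaskGP directly and the partial wrapper would be unnecessary.
cv_results = batch_cross_validation(
    model_cls=SingleTaskGP,
    mll_cls=ExactMarginalLogLikelihood,
    cv_folds=cv_folds,
    model_init_kwargs={
        "outcome_transform": Standardize(m=1, batch_shape=torch.Size([20])),
        "input_transform": Normalize(d=1),
    },
)
```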
This was fixed in https://github.com/pytorch/botorch/pull/2269