botorch

[Bug] `batch_cross_validation` should not always pass `train_Yvar` to the model

Open • Balandat opened this issue 2 years ago • 1 comment

🐛 Bug

`batch_cross_validation` passes `train_Yvar=None` to model classes that don't accept that argument (e.g. `SingleTaskGP`).

Instead, it should pass `train_Yvar` only if it is not `None`.
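A minimal sketch of the proposed behavior, written in plain Python so it runs without botorch installed. `build_fold_model`, `ToyGP`, and `ToyFixedNoiseGP` are hypothetical names for illustration; this is not the code from the eventual fix, just the "only pass it when it exists" pattern.

```python
from typing import Any, Callable, Dict, Optional


def build_fold_model(
    model_cls: Callable[..., Any],
    train_X: Any,
    train_Y: Any,
    train_Yvar: Optional[Any] = None,
) -> Any:
    """Construct a model, passing train_Yvar only when it is not None."""
    kwargs: Dict[str, Any] = {"train_X": train_X, "train_Y": train_Y}
    if train_Yvar is not None:
        # Forward observed noise variances only when the folds carry them.
        kwargs["train_Yvar"] = train_Yvar
    return model_cls(**kwargs)


class ToyGP:  # hypothetical model that infers noise (no train_Yvar parameter)
    def __init__(self, train_X, train_Y):
        self.train_X = train_X
        self.train_Y = train_Y


class ToyFixedNoiseGP:  # hypothetical model that takes observed noise
    def __init__(self, train_X, train_Y, train_Yvar):
        self.train_X = train_X
        self.train_Y = train_Y
        self.train_Yvar = train_Yvar


# No train_Yvar: ToyGP receives only train_X and train_Y, so nothing is dropped.
gp = build_fold_model(ToyGP, [0.1, 0.5], [1.0, 2.0])
# With train_Yvar: it is forwarded to a model that accepts it.
fixed = build_fold_model(ToyFixedNoiseGP, [0.1], [1.0], train_Yvar=[0.01])
```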

Additional context

Originally flagged in #1667

Balandat, Feb 10 '23

For context, the following code raises a confusing warning through no fault of the user:

```python
from functools import partial

import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood

# 20 noisy observations of sin(6x) with x in [0, 2].
train_X = torch.rand(20, 1, dtype=torch.float64) * 2
train_Y = (torch.sin(6 * train_X) + 0.2 * torch.rand_like(train_X)).squeeze()

# Leave-one-out on 20 points gives 20 folds.
cv_folds = gen_loo_cv_folds(train_X=train_X, train_Y=train_Y)

batch_cross_validation(
    partial(
        SingleTaskGP,
        # batch_shape must match the number of folds (20).
        outcome_transform=Standardize(m=1, batch_shape=(20,)),
        input_transform=Normalize(d=1),
    ),
    ExactMarginalLogLikelihood,
    cv_folds,
)
```

This emits:

UserWarning: Keyword arguments ['train_Yvar'] will be ignored because they are not allowed parameters. Allowed parameters are ['train_X', 'train_Y', 'likelihood', 'covar_module', 'mean_module', 'outcome_transform', 'input_transform'].
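The warning comes from filtering keyword arguments against the model constructor's signature. The sketch below illustrates that mechanism in plain Python; it is not botorch's source, and `filter_kwargs` and `toy_model` are hypothetical names. Note that `inspect.signature` on a `functools.partial` still lists the keyword-bound parameters, which is consistent with `outcome_transform` and `input_transform` showing up in the warning's list of allowed parameters.

```python
import inspect
import warnings
from functools import partial
from typing import Any, Callable, Dict


def filter_kwargs(func: Callable[..., Any], **kwargs: Any) -> Dict[str, Any]:
    """Keep only kwargs that appear in func's signature; warn about the rest."""
    allowed = list(inspect.signature(func).parameters)
    ignored = [k for k in kwargs if k not in allowed]
    if ignored:
        warnings.warn(
            f"Keyword arguments {ignored} will be ignored because they are not "
            f"allowed parameters. Allowed parameters are {allowed}."
        )
    return {k: v for k, v in kwargs.items() if k not in ignored}


def toy_model(train_X, train_Y, outcome_transform=None):  # hypothetical stand-in
    return train_X, train_Y, outcome_transform


# As in the issue, the model constructor is wrapped in functools.partial.
factory = partial(toy_model, outcome_transform="standardize")
# Passing train_Yvar=None unconditionally triggers the warning even though
# the caller never supplied noise variances.
kept = filter_kwargs(factory, train_X=[0.1], train_Y=[1.0], train_Yvar=None)
```

Skipping `train_Yvar` when it is `None`, as proposed above, means the filter never sees it and the warning disappears.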

Other possible improvements to `batch_cross_validation`:

  1. Allow `batch_cross_validation` to pass transforms or other keyword arguments to the model initializer, so that users don't need to wrap the model class in `functools.partial`
  2. Make cross-validation work better with transforms; getting the transform's batch shape right (e.g. `batch_shape=(20,)` for 20 folds above) is unintuitive
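Improvement 1 could look roughly like the sketch below: a hypothetical `cross_validate` that accepts a `model_init_kwargs` dict and forwards it to the constructor, so no `partial` is needed. The names `gen_loo_folds`, `cross_validate`, `model_init_kwargs`, and `ToyModel` are assumptions for illustration, not botorch API. It loops over folds for readability, whereas `batch_cross_validation` fits all folds as one batched model, which is why the transforms in the example above need `batch_shape=(20,)`.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

# (train_x, train_y, test_x, test_y) for one fold.
Fold = Tuple[list, list, list, list]


def gen_loo_folds(xs: Sequence[Any], ys: Sequence[Any]) -> List[Fold]:
    """Leave-one-out split: fold i trains on every point except point i."""
    folds: List[Fold] = []
    for i in range(len(xs)):
        train_x = [x for j, x in enumerate(xs) if j != i]
        train_y = [y for j, y in enumerate(ys) if j != i]
        folds.append((train_x, train_y, [xs[i]], [ys[i]]))
    return folds


def cross_validate(
    model_cls: Callable[..., Any],
    folds: List[Fold],
    model_init_kwargs: Optional[Dict[str, Any]] = None,
) -> List[Any]:
    """Build one model per fold, forwarding extra constructor kwargs unchanged."""
    extra = dict(model_init_kwargs or {})
    return [model_cls(train_X=tx, train_Y=ty, **extra) for tx, ty, _, _ in folds]


class ToyModel:  # hypothetical stand-in for a model class such as SingleTaskGP
    def __init__(self, train_X, train_Y, outcome_transform=None):
        self.n_train = len(train_X)
        self.outcome_transform = outcome_transform


folds = gen_loo_folds(list(range(20)), [float(i) for i in range(20)])
models = cross_validate(
    ToyModel, folds, model_init_kwargs={"outcome_transform": "standardize"}
)
```

With 20 points, leave-one-out produces 20 folds of 19 training points each. In a batched implementation, that fold count is the value a transform's batch shape has to match, so deriving it inside the helper is one possible way to address improvement 2.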

esantorella, May 05 '23

This was fixed in https://github.com/pytorch/botorch/pull/2269

saitcakmak, Jul 24 '24