[Bug] `KroneckerMultiTaskGP` incompatible with `batch_cross_validation` or batched models

Open slishak-PX opened this issue 1 year ago • 2 comments

🐛 Bug

When trying to use `batch_cross_validation` to fit a `KroneckerMultiTaskGP`, a batch shape error occurs.

The code snippet under "To reproduce" is adapted from the example in https://botorch.org/tutorials/batch_mode_cross_validation.

Constructing the model directly on the batched folds, without `batch_cross_validation`, causes the same error, which suggests that `KroneckerMultiTaskGP` does not correctly support batching:

gp = KroneckerMultiTaskGP(cv_folds.train_X, cv_folds.train_Y)
gp.posterior(cv_folds.test_X)

To reproduce

**Code snippet to reproduce**

import math
import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import SingleTaskGP, KroneckerMultiTaskGP
from botorch.models.transforms.input import Normalize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

device = torch.device("cuda:0")
dtype = torch.float64
torch.manual_seed(3)

sigma = math.sqrt(0.2)
coefs = torch.linspace(1, 4, 4, dtype=dtype, device=device).view(1, -1)
train_X = torch.linspace(0, 1, 20, dtype=dtype, device=device).view(-1, 1)
train_Y_noiseless = torch.sin(train_X * coefs * math.pi)
train_Y = train_Y_noiseless + sigma * torch.randn_like(train_Y_noiseless)

cv_folds = gen_loo_cv_folds(train_X, train_Y)

# If this is False, a SingleTaskGP is used
multi_task = True

cv_results = batch_cross_validation(
    model_cls=KroneckerMultiTaskGP if multi_task else SingleTaskGP,
    mll_cls=ExactMarginalLogLikelihood,
    cv_folds=cv_folds,
    model_init_kwargs={
        "input_transform": Normalize(d=train_X.shape[-1]),
    },
)

**Stack trace/error message**

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/kernels/multitask_kernel.py:53, in MultitaskKernel.forward(self, x1, x2, diag, last_dim_is_batch, **params)
     51     covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
     52 covar_x = to_linear_operator(self.data_covar_module.forward(x1, x2, **params))
---> 53 res = KroneckerProductLinearOperator(covar_x, covar_i)
     54 return res.diagonal(dim1=-1, dim2=-2) if diag else res

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File /opt/conda/envs/python3/lib/python3.10/site-packages/linear_operator/operators/kronecker_product_linear_operator.py:81, in KroneckerProductLinearOperator.__init__(self, *linear_ops)
     79     batch_broadcast_shape = torch.broadcast_shapes(*(linear_op.batch_shape for linear_op in linear_ops))
     80 except RuntimeError:
---> 81     raise RuntimeError(
     82         "Batch shapes of LinearOperators "
     83         f"({', '.join([str(tuple(linear_op.shape)) for linear_op in linear_ops])}) "
     84         "are incompatible for a Kronecker product."
     85     )
     87 if len(batch_broadcast_shape):  # Otherwise all linear_ops are non-batch, and we don't need to expand
     88     # NOTE: we must explicitly call requires_grad on each of these arguments
     89     # for the automatic _bilinear_derivative to work in torch.autograd.Functions
     90     linear_ops = tuple(
     91         linear_op._expand_batch(batch_broadcast_shape).requires_grad_(linear_op.requires_grad)
     92         for linear_op in linear_ops
     93     )

RuntimeError: Batch shapes of LinearOperators ((20, 19, 19), (400, 4, 4)) are incompatible for a Kronecker product.
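
For reference, the numbers in the error message can be reproduced with plain shape arithmetic. The sketch below is pure Python (no torch needed). `repeat_shape` and `broadcast_batch_shapes` are hypothetical helpers written for this illustration; they only mimic the shape rules of `Tensor.repeat` and `torch.broadcast_shapes`. It also assumes, without this having been confirmed in the BoTorch source, that the task covariance already carries the batch dimension of 20 folds before line 51 of `multitask_kernel.py` repeats it, which would explain 400 = 20 x 20.

```python
def repeat_shape(shape, repeats):
    """Shape that Tensor.repeat(*repeats) would produce for a tensor of `shape`."""
    if len(repeats) < len(shape):
        raise ValueError("repeat needs at least as many sizes as tensor dims")
    # Missing leading dims are treated as size 1, then sizes multiply elementwise.
    padded = (1,) * (len(repeats) - len(shape)) + tuple(shape)
    return tuple(s * r for s, r in zip(padded, repeats))


def broadcast_batch_shapes(a, b):
    """Right-aligned broadcasting of two shapes; ValueError if incompatible."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(out)


# 20 leave-one-out folds of 19 training points each, 4 tasks (from the repro).
x1_batch_shape = (20,)
data_covar_shape = x1_batch_shape + (19, 19)

# Line 51 of the trace: covar_i.repeat(*x1.shape[:-2], 1, 1)
repeats = x1_batch_shape + (1, 1)

# Case A: unbatched task covariance (4, 4). The repeat adds the batch dim.
task_unbatched = repeat_shape((4, 4), repeats)

# Case B (assumption): task covariance already batched as (20, 4, 4).
# The same repeat multiplies the existing batch dim instead of adding one.
task_batched = repeat_shape((20, 4, 4), repeats)

# Case A broadcasts cleanly against the data covariance batch shape.
ok_batch = broadcast_batch_shapes(data_covar_shape[:-2], task_unbatched[:-2])

# Case B does not: (20,) vs (400,) matches the shapes in the RuntimeError.
try:
    broadcast_batch_shapes(data_covar_shape[:-2], task_batched[:-2])
    case_b_fails = False
except ValueError:
    case_b_fails = True
```

If this assumption holds, a repeat that is skipped or replaced by an expand when the task covariance is already batched would avoid the mismatch, but that has not been tested here.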

Expected Behavior

System information

  • BoTorch 0.12.0 (also occurs with 0.11.3, although this code requires 0.12 due to the lack of Standardize)
  • GPyTorch 1.13 (also occurs with 1.12)
  • PyTorch 2.0.1+cu117
  • Linux

Additional context

slishak-PX · Sep 26 '24 14:09