[Bug] Unable to export multi-output SingleTaskGP to TorchScript
🐛 Bug
Exporting a SingleTaskGP trained on multi-output data to TorchScript fails (this probably applies to other GPs as well).
To reproduce
import numpy as np
import torch
from torch import Tensor
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import Standardize, Normalize
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.settings import trace_mode
# Synthetic two-output training data: 100 inputs evenly spaced on [0, 1],
# a noisy sine and a noisy cosine target (Gaussian noise, std 0.2), stacked
# column-wise into Y of shape (100, 2).
X = torch.linspace(0, 1, 100).view(-1, 1)
y1 = torch.sin(2 * np.pi * X) + torch.randn_like(X) * 0.2
y2 = torch.cos(2 * np.pi * X) + torch.randn_like(X) * 0.2
Y = torch.cat([y1, y2], dim=1)
# BoTorch transforms: Normalize over the single input dimension and
# Standardize over both outcome columns.
input_transform = Normalize(d=1)
outcome_transform = Standardize(m=2)
# One SingleTaskGP fit to both outputs -- this multi-output model is the
# case that fails to export to TorchScript.
gp = SingleTaskGP(
X,
Y,
outcome_transform=outcome_transform,
input_transform=input_transform,
)
# Fit the GP hyperparameters against the exact marginal log likelihood.
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
class MeanVarModelWrapper(torch.nn.Module):
    """Expose a model's posterior mean and standard deviation as plain tensors.

    ``forward`` returns a ``(mean, std)`` tuple of tensors instead of a
    posterior object, so the wrapper itself can be handed to
    ``torch.jit.trace``.
    """

    def __init__(self, model):
        super().__init__()
        # The wrapped model; only its ``posterior`` method is used.
        self.model = model

    def forward(self, x):
        """Return the posterior mean and standard deviation at ``x``.

        Args:
            x: Test inputs, passed unchanged to ``self.model.posterior``.

        Returns:
            A tuple ``(mean, std)`` of detached tensors, where ``std`` is the
            element-wise square root of the posterior variance.
        """
        # get the model posterior (observation_noise=True is forwarded as-is)
        posterior = self.model.posterior(x, observation_noise=True)
        # detach so the outputs carry no autograd history
        mean = posterior.mean.detach()
        std = posterior.variance.sqrt().detach()
        return mean, std
# Ten random test inputs shaped (10, 1), matching the 1-d training inputs.
X_test = torch.rand(10).view(-1, 1)
wrapped_model = MeanVarModelWrapper(gp)
# Both statements must run inside no_grad() and trace_mode(): the first call
# warms up the model (per the original note, it computes the caches) and the
# second records the TorchScript trace under the same settings.
with torch.no_grad(), trace_mode():
    wrapped_model(X_test)  # Compute caches
    traced_model = torch.jit.trace(wrapped_model, X_test)
**Stack trace/error message**
{
"name": "RuntimeError",
"message": "mean shape torch.Size([10, 2]) is incompatible with covariance shape torch.Size([160, 160])",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[4], line 6
4 with torch.no_grad(), trace_mode():
5 wrapped_model(X_test) # Compute caches
----> 6 traced_model = torch.jit.trace(wrapped_model, X_test)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1000, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
993 from torch._utils_internal import (
994 check_if_torch_exportable,
995 log_torch_jit_trace_exportability,
996 log_torchscript_usage,
997 )
999 log_torchscript_usage(\"trace\")
-> 1000 traced_func = _trace_impl(
1001 func,
1002 example_inputs,
1003 optimize,
1004 check_trace,
1005 check_inputs,
1006 check_tolerance,
1007 strict,
1008 _force_outplace,
1009 _module_class,
1010 _compilation_unit,
1011 example_kwarg_inputs,
1012 _store_inputs,
1013 )
1015 if check_if_torch_exportable():
1016 from torch._export.converter import TS2EPConverter
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:695, in _trace_impl(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
693 else:
694 raise RuntimeError(\"example_kwarg_inputs should be a dict\")
--> 695 return trace_module(
696 func,
697 {\"forward\": example_inputs},
698 None,
699 check_trace,
700 wrap_check_inputs(check_inputs),
701 check_tolerance,
702 strict,
703 _force_outplace,
704 _module_class,
705 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
706 _store_inputs=_store_inputs,
707 )
708 if (
709 hasattr(func, \"__self__\")
710 and isinstance(func.__self__, torch.nn.Module)
711 and func.__name__ == \"forward\"
712 ):
713 if example_inputs is None:
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1275, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1273 else:
1274 example_inputs = make_tuple(example_inputs)
-> 1275 module._c._create_method_from_trace(
1276 method_name,
1277 func,
1278 example_inputs,
1279 var_lookup_fn,
1280 strict,
1281 _force_outplace,
1282 argument_names,
1283 _store_inputs,
1284 )
1286 check_trace_method = module._c._get_method(method_name)
1288 # Check the trace against new traces created from user-specified inputs
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1543, in Module._slow_forward(self, *input, **kwargs)
1541 recording_scopes = False
1542 try:
-> 1543 result = self.forward(*input, **kwargs)
1544 finally:
1545 if recording_scopes:
Cell In[3], line 8, in MeanVarModelWrapper.forward(self, x)
6 def forward(self, x):
7 # get the model posterior
----> 8 posterior = self.model.posterior(x, observation_noise=True)
9 mean = posterior.mean.detach()
10 std = posterior.variance.sqrt().detach()
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/models/gpytorch.py:459, in BatchedMultiOutputGPyTorchModel.posterior(self, X, output_indices, observation_noise, posterior_transform)
451 output_indices = output_indices or range(self._num_outputs)
452 mvns = [
453 MultivariateNormal(
454 mean_x.select(dim=output_dim_idx, index=t),
(...)
457 for t in output_indices
458 ]
--> 459 mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
460 # mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn)
462 posterior = GPyTorchPosterior(distribution=mvn)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/distributions/multitask_multivariate_normal.py:199, in MultitaskMultivariateNormal.from_independent_mvns(cls, mvns)
193 covar_blocks_lazy = CatLinearOperator(
194 *[mvn.lazy_covariance_matrix.unsqueeze(0) for mvn in mvns],
195 dim=0,
196 output_device=mean.device,
197 )
198 covar_lazy = BlockDiagLinearOperator(covar_blocks_lazy, block_dim=0)
--> 199 return cls(mean=mean, covariance_matrix=covar_lazy, interleaved=False)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.__call__(cls, *args, **kwargs)
24 if result is not None:
25 return result
---> 26 return super().__call__(*args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/distributions/multitask_multivariate_normal.py:71, in MultitaskMultivariateNormal.__init__(self, mean, covariance_matrix, validate_args, interleaved)
65 mean = mean.expand(
66 *batch_shape,
67 mean.size(-2),
68 covariance_matrix.size(-2) // mean.size(-2),
69 )
70 else:
---> 71 raise RuntimeError(
72 f\"mean shape {mean.shape} is incompatible with covariance shape {covariance_matrix.shape}\"
73 )
74 else:
75 mean = mean.expand(*batch_shape, *mean.shape[-2:])
RuntimeError: mean shape torch.Size([10, 2]) is incompatible with covariance shape torch.Size([160, 160])"
}
Expected Behavior
The model predictor should be converted into a TorchScript module.
System information
Please complete the following information:
- BoTorch Version: 0.12.0
- GPyTorch Version: 1.13
- PyTorch Version: 2.4.1
- macOS Sonoma 14.5
Additional context
This is a bug in the posterior method of the BatchedMultiOutputGPyTorchModel class. The code that builds the list of per-output MultivariateNormal distributions is fairly involved, and the subsequent construction of a MultitaskMultivariateNormal from these distributions (via from_independent_mvns) fails when the lazy_covariance_matrix is evaluated in trace mode. Instead of this juggling, we could use MultitaskMultivariateNormal.from_batch_mvn with the task_dim parameter to directly create the required MultitaskMultivariateNormal posterior when in trace mode. Doing so fixed the error shown above.
A potential fix is to branch the code inside BatchedMultiOutputGPyTorchModel's posterior method depending on whether trace mode is on.
When trace mode is on, it could simply be:
if self._num_outputs > 1:
    # Build the multitask posterior directly from the batched MVN, with the
    # outputs along dimension 0, instead of going through from_independent_mvns.
    mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0)
Otherwise, it would keep the existing behavior.
This only changes how the lazy_covariance_matrix is evaluated in trace mode; it fixes this error without breaking the existing tests.
A potential solution has been introduced in #2592.
Hmm, interesting. At least in the past there must have been some reason not to just use MultitaskMultivariateNormal.from_batch_mvn() here. I don't recall exactly why without digging into this more — possibly because the BlockInterleavedLinearOperator it produces may cause issues (possibly performance-related, possibly incompatibilities between that operator and other downstream operations)?
I have opened a PR with a temporary fix that at least enables exporting the model to TorchScript. It uses from_batch_mvn only when the posterior method is called in trace mode. Please let me know if there is a better fix for this issue.