
[Bug] SingleTaskVariationalGP raises a warning when using input_transform

Open crasanders opened this issue 2 years ago • 4 comments

🐛 Bug

Calling SingleTaskVariationalGP.posterior with an input_transform raises a warning, whereas the equivalent call with SingleTaskGP does not. I'm not sure if input_transform works correctly with SingleTaskVariationalGP or if I can safely interpret the resulting posterior. It also seems a little odd to me that this would be a warning and not an exception.

To reproduce

Code snippet to reproduce

import torch
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
from botorch.models.transforms import Normalize

X = torch.rand((20, 1))
y = torch.sin(X)

model = SingleTaskGP(train_X=X, train_Y=y, input_transform=Normalize(1))
post = model.posterior(X) # No warning

model = SingleTaskVariationalGP(train_X=X, train_Y=y, input_transform=Normalize(1))
post = model.posterior(X) # Warning

Stack trace/error message

RuntimeWarning: Could not update `train_inputs` with transformed inputs since SingleTaskVariationalGP does not have a `train_inputs` attribute. Make sure that the `input_transform` is applied to both the train inputs and test inputs.

System information

botorch version = 0.8.5
gpytorch version = 1.10
torch version = 1.13.1

crasanders, May 10 '23 22:05
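The reporter notes that an exception might be more appropriate than a warning here. Until that is settled, Python's standard warnings module can escalate just this one message to an error, or record it, without affecting other warnings. The sketch below is hypothetical: fake_posterior, posterior_strict and posterior_recording are stand-ins that only reproduce the warning text from this issue, so the block runs without BoTorch installed. In real code you would wrap model.posterior(X) instead of fake_posterior().

```python
import warnings

# Regex matched against the start of the warning message (the text reported above).
MSG = r"Could not update `train_inputs`"


def fake_posterior():
    """Hypothetical stand-in for SingleTaskVariationalGP.posterior.

    It does no modeling; it only emits the same RuntimeWarning as in the issue.
    """
    warnings.warn(
        "Could not update `train_inputs` with transformed inputs since "
        "SingleTaskVariationalGP does not have a `train_inputs` attribute. "
        "Make sure that the `input_transform` is applied to both the train "
        "inputs and test inputs.",
        RuntimeWarning,
    )
    return "posterior"


def posterior_strict():
    """Call fake_posterior, turning only this specific warning into an exception."""
    with warnings.catch_warnings():
        warnings.filterwarnings("error", message=MSG, category=RuntimeWarning)
        return fake_posterior()


def posterior_recording():
    """Call fake_posterior and return (result, list of recorded warnings)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = fake_posterior()
    return result, caught


if __name__ == "__main__":
    try:
        posterior_strict()
    except RuntimeWarning as err:
        print("escalated to an exception:", err)
    result, caught = posterior_recording()
    print(result, [w.category.__name__ for w in caught])
```

Because the filter is scoped by warnings.catch_warnings(), the escalation only applies inside that block and is undone on exit.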

Thanks for reporting this -- that is a concerning warning! I believe the warning is erroneous and the input transforms are being applied appropriately. Here's an example:

import torch
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
from botorch.models.transforms import Normalize
from matplotlib import pyplot as plt
from gpytorch.mlls import VariationalELBO, ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll

train_X = torch.linspace(1, 3, 10, dtype=torch.double)[:, None]
y = -3 * train_X + 5
test_X = torch.linspace(1, 5, 10, dtype=torch.double)[:, None]

model = SingleTaskGP(train_X=train_X, train_Y=y, input_transform=Normalize(1))
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
post = model.posterior(test_X)

model = SingleTaskVariationalGP(train_X=train_X, train_Y=y, input_transform=Normalize(1))
mll = VariationalELBO(
    model.likelihood, model.model, num_data=train_X.shape[-2]
)
fit_gpytorch_mll(mll)
post_var = model.posterior(test_X) # Warning

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
axes[0].scatter(train_X, y, label="train data")
axes[0].plot(test_X, post.mean.detach().numpy(), label="posterior mean")
axes[0].legend()
axes[0].set_title("SingleTaskGP")
axes[1].scatter(train_X, y, label="train data")
axes[1].plot(test_X, post_var.mean.detach().numpy(), label="posterior mean")

axes[1].legend()
axes[1].set_title("VariationalGP")
for ax in axes:
    ax.set_xlabel("X")
axes[0].set_ylabel("y")
plt.show()

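[Figure: side-by-side panels titled "SingleTaskGP" and "VariationalGP", each plotting the train data and the posterior mean, y against X]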

Variational GPs handle input transforms differently from most models. When posterior is called, the variational GP first tries to apply the standard input-transform logic used by other models, fails, emits the warning, and then correctly applies its own input-transform logic, here. I'll put in a PR so that this warning doesn't happen.

esantorella, May 11 '23 15:05
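To make the control flow described in the previous comment concrete, here is a deliberately simplified pure-Python mock. It is not BoTorch's code: MockModel, MockExactGP, MockVariationalGP and the scale helper are hypothetical. The shared "generic" step tries to re-transform a train_inputs attribute and warns when that attribute is missing; the mock variational model has no train_inputs, so the generic step warns, and the model then transforms the test inputs through its own path, giving the same result as the exact model.

```python
import warnings


def scale(xs, lo, hi):
    """Hypothetical min-max transform, loosely mirroring a Normalize-style transform."""
    return [(x - lo) / (hi - lo) for x in xs]


class MockModel:
    """Holds the generic input-transform step shared by most (mock) models."""

    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def _generic_transform_step(self):
        # Generic logic: update stored train inputs with transformed values.
        if hasattr(self, "train_inputs"):
            self.train_inputs = scale(self.raw_train_inputs, self.lo, self.hi)
        else:
            warnings.warn(
                "Could not update `train_inputs` with transformed inputs since "
                f"{type(self).__name__} does not have a `train_inputs` attribute.",
                RuntimeWarning,
            )


class MockExactGP(MockModel):
    def __init__(self, train_xs, lo, hi):
        super().__init__(lo, hi)
        self.raw_train_inputs = list(train_xs)
        self.train_inputs = list(train_xs)

    def posterior(self, xs):
        self._generic_transform_step()  # succeeds silently
        return scale(xs, self.lo, self.hi)


class MockVariationalGP(MockModel):
    # No train_inputs attribute in this mock.

    def posterior(self, xs):
        self._generic_transform_step()  # fails and warns
        # Model-specific path: transform the test inputs directly.
        return scale(xs, self.lo, self.hi)


if __name__ == "__main__":
    exact = MockExactGP([1.0, 3.0], lo=1.0, hi=3.0)
    var = MockVariationalGP(lo=1.0, hi=3.0)
    print(exact.posterior([2.0]), var.posterior([2.0]))  # the variational call warns
```

In this mock the warning is a side effect of running a generic step that does not apply to the model, not a sign that the test inputs were left untransformed.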

cc @saitcakmak this is relevant for the proposed transforms refactor in https://github.com/cornellius-gp/gpytorch/pull/2114

Balandat, May 11 '23 15:05

Great, glad it's a false alarm. Thanks for the quick reply.

crasanders, May 11 '23 15:05

I vaguely recall an issue with ApproximateGP where inducing points sometimes get transformed and sometimes don't. Looking at my old notes, I found this:

It does not play well with ApproximateGP, since inducing points sometimes get transformed (in train mode) and sometimes don’t (in posterior). This leads to bugs as we train them under one setup and evaluate them under another.

So, there might be some truth to the warning here. The proper solution would be to push through https://github.com/pytorch/botorch/pull/1372 and fix this for good. I'll leave this open just in case.

saitcakmak, May 12 '23 17:05
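The concern in the quoted note (inducing points transformed in train mode but not in the posterior) can be illustrated numerically with a toy scalar RBF kernel. This is a hypothetical sketch, not BoTorch or GPyTorch internals: it assumes a single inducing point z stored in raw units, min-max scaling with bounds taken from the training data, and an arbitrary lengthscale of 0.5. It compares the kernel value between a test point and z along the two paths the note describes.

```python
import math


def normalize(x, lo, hi):
    """Min-max scaling to [0, 1], similar in spirit to a Normalize input transform."""
    return (x - lo) / (hi - lo)


def rbf(a, b, lengthscale=0.5):
    """Squared-exponential kernel for scalar inputs."""
    return math.exp(-((a - b) ** 2) / (2 * lengthscale**2))


LO, HI = 1.0, 3.0  # assumed bounds learned from the training data
z = 2.0            # hypothetical inducing point, stored in raw units
x_test = 2.0       # test input in raw units (coincides with z)

# Train-mode path as described in the note: both x and z are transformed.
k_train = rbf(normalize(x_test, LO, HI), normalize(z, LO, HI))

# Posterior path as described in the note: x is transformed, z is not.
k_post = rbf(normalize(x_test, LO, HI), z)

if __name__ == "__main__":
    print(f"train-mode kernel value: {k_train:.4f}")
    print(f"posterior kernel value:  {k_post:.4f}")
```

For a test point that coincides with the inducing point, the train-mode path gives a kernel value of 1.0 while the posterior path gives exp(-4.5), roughly 0.011. A model trained under one path and evaluated under the other would therefore see very different covariances, which is the kind of bug the note warns about.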