[Bug] load_state_dict doesn't invalidate cached transformed inputs
🐛 Bug
BoTorch models do not properly support model.load_state_dict when using input transforms with trained parameters. After model.load_state_dict is called, the model keeps using cached transformed inputs that were computed with the previous parameters.
GPyTorch does not have this bug in its own caching; it intentionally clears all caches whenever a state dict is loaded.
Workaround: call model.train() before calling model.load_state_dict().
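A minimal sketch of the workaround (it assumes a model and a state_dict as in the reproduction script below):

# Workaround sketch: put the model in train mode first so the cached
# transformed train inputs are reverted, then load the state dict.
model.train()
model.load_state_dict(state_dict)
model.eval()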
To reproduce
import copy

import botorch
import gpytorch
import torch

train_X = torch.tensor([[0., 0.],
                        [0.1, 0.1],
                        [0.5, 0.5]])
train_Y = torch.tensor([[0.],
                        [0.5],
                        [1.0]])
test_X = torch.tensor([[0.3, 0.3]])

model = botorch.models.SingleTaskGP(
    train_X, train_Y,
    # This is one example of an input transform that stores trained parameters
    # in the state dict.
    input_transform=botorch.models.transforms.Warp(indices=[0, 1]))

state_dict = copy.deepcopy(model.state_dict())

# Check initial behavior.
model.eval()
print(f"Before: {model(test_X).mean.item()}")

# Train the model, adjusting the Warp parameters and caching transformed inputs.
botorch.fit_gpytorch_mll(
    gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
)

# Workaround: uncomment the following line.
# model.train()

# Revert to the original parameters.
model.load_state_dict(state_dict)

# Verify that the output matches the original output.
model.eval()
print(f"After: {model(test_X).mean.item()}")
Actual output
Before: 0.2642167806625366
After: 0.21983212232589722
Expected output
Before: 0.2642167806625366
After: 0.2642167806625366
System information
BoTorch 0.7.2
GPyTorch 1.9.0
PyTorch 1.13.0
macOS 13.0
This is just a product of how these transforms are implemented. Their attributes are not cached the way GPyTorch caches train_train_covar etc.; they are buffers (or, in the case of Warp, learnable parameters). They are included in the output of model.state_dict(), and likewise should be included in the model.load_state_dict(state_dict) call if you want them to get updated.
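As an illustration, you can inspect which state-dict entries belong to the input transform. This is a sketch that reuses the model from the reproduction script above; the exact key names (e.g. "input_transform.concentration0") are an assumption and may differ between BoTorch versions.

# Sketch: list the state-dict entries contributed by the input transform.
transform_keys = [k for k in model.state_dict() if k.startswith("input_transform")]
print(transform_keys)

# The Warp parameters also show up as regular learnable parameters, so they
# are trained alongside the other GP hyperparameters by fit_gpytorch_mll.
print([name for name, _ in model.named_parameters() if "input_transform" in name])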
For the particular example of Warp, it is a trainable transform: its parameters are trained alongside the other model hyperparameters. So, invalidating these on load_state_dict would not lead to any meaningful outcome when you call the model (unless you retrain the model afterwards).
It also seems like the differences you see in your example are not due to some caching on the side of Warp or any buffers. When you call model.load_state_dict(state_dict), you reset the parameters of Warp to the originals (which can be verified by checking the state dict). ~~The thing that changes on the train() call is that GPyTorch deletes the prediction_strategy. So, the caches that stick around are actually the caches on model.prediction_strategy.~~
Nevermind, I was printing the prediction_strategy in the wrong spot.
Ok, here's what's happening. After model training, an mll.eval() call triggers this bit of code that updates model.train_inputs with the transformed inputs. When you reload the state dict, you reset the input transform but not the transformed inputs, hence the odd behavior. The model.train() call reverts the train inputs to the originals, fixing this mismatch. This behavior will be cleaned up as part of #1372, which I will get back to in the next couple of weeks.
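A quick way to see this mismatch, as a sketch that reuses model and train_X from the reproduction above and assumes the behavior described in this comment:

# Sketch: after fitting and load_state_dict, model.train_inputs still holds
# the inputs that were warped with the old (trained) Warp parameters, so they
# generally differ from the original training data.
print(torch.allclose(model.train_inputs[0], train_X))  # expected: False

# Switching to train mode reverts train_inputs to the original, untransformed data.
model.train()
print(torch.allclose(model.train_inputs[0], train_X))  # expected: True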
Another thing to note: currently, for the input transforms to be applied in eval mode, you should call the model through model.posterior. Otherwise, the input transforms will only be applied when the model is in train mode, leading to buggy behavior. This is also being cleaned up as part of #1372.
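For example, the evaluation in the reproduction script would more safely go through the posterior. This is a sketch reusing model and test_X from above:

# Preferred call path in eval mode: model.posterior applies the input
# transform before evaluating the GP.
model.eval()
posterior = model.posterior(test_X)
print(posterior.mean.item())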
Cool, I'm looking forward to watching that PR. Maybe this is obvious at this point, but a quick fix to the current codebase would be to add to botorch.models.Model:
def _load_from_state_dict(self, *args, **kwargs):
    self._revert_to_original_inputs()
    super()._load_from_state_dict(*args, **kwargs)
(Here I'm just copying the approach used in gpytorch.)
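For reference, the GPyTorch pattern being copied looks roughly like the following. This is a hedged sketch from memory, not a verbatim quote of the library source; the actual override on ExactGP takes the full _load_from_state_dict signature.

# Rough sketch of the GPyTorch approach: drop the cached prediction state,
# then delegate loading to the parent module as usual.
def _load_from_state_dict(self, *args, **kwargs):
    self.prediction_strategy = None  # invalidate cached prediction strategy
    super()._load_from_state_dict(*args, **kwargs)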